from pyprojroot import here
from isssm.laplace_approximation import posterior_mode
from isssm.laplace_approximation import posterior_mode
from isssm.importance_sampling import ess_pct
import pandas as pd
from isssm.importance_sampling import pgssm_importance_sampling
from isssm.ce_method import log_weight_cem, simulate_cem
from jax import vmap
from functools import partial
from isssm.laplace_approximation import laplace_approximation as LA
from isssm.modified_efficient_importance_sampling import (
modified_efficient_importance_sampling as MEIS,
)
from isssm.ce_method import cross_entropy_method as CEM
from isssm.pgssm import simulate_pgssm
import jax.random as jrn
import jax.numpy as jnp
import jax
from isssm.typing import PGSSM
from tensorflow_probability.substrates.jax.distributions import Poisson
from tqdm.notebook import tqdm

Comparison of EIS and the CEM for SSMs
Simplified version of regional model in Chapter 4.1, keeping only \(\log I_t\) and \(\log \rho_t\) in the states.
- States \(X_t = \left(\log I_{t}, \log \rho_{t + 1}\right)\)
- Observations \(Y_t | X_t \sim \operatorname{Pois} \left( \exp \log I_{t}\right)\)
Varying \(n = 10, 100, 1000\). Initialize \(\log \rho_0 = 0\) with small variance and \(\log I_0 = \log 1000\) with small variance as well.
Let \(\sigma^2_\rho = \frac{0.05}{n}\), s.t. \(\operatorname{Var} (\log \rho_{n+1}) = 0.05\) and approx. \(\mathbf P(\log \rho_{n+1} \in [-0.1, 0.1]) \geq 0.95\), so approx. \(\rho_{n+1} \in [0.9, 1.1]\), ensuring stability of infection counts (don’t go to \(0\) or \(\infty\)).
# Laplace approximations and slogdets need double precision.
jax.config.update("jax_enable_x64", True)

# parameters
N_samples = 10_000  # samples used when fitting / estimating proposals
N_ef = 1_000  # fresh samples for the efficiency-factor evaluation
N_iter = 100  # iterations of CEM / EIS (and the LA)
M = 100  # replications used to estimate covariance matrices
K = 10  # repetitions of the asymptotic-variance experiment
K_ef = 100  # repetitions of the efficiency-factor experiment

# Export the parameter description for the asymptotic-variance experiment.
parameters_tex = f"""
We set the number of iterations of the \\gls{{cem}} and \\gls{{eis}} to ${N_iter}$, which, in our experience, suffices to determine whether the numerical scheme converges or diverges. We use $M={M}$ samples to obtain the covariance matrices. For both methods we use $N = {N_samples}$ samples for estimation.
The above procedure generates a single asymptotic variance ratios for a fixed number of time points $n$. As the performance of importance sampling is likely influenced by the sample $y$, we repeat the simulation $K={K}$ times to obtain $K$ different outcomes.
"""
var_tex_path = here(
    "chapters/03_state_space_models/03_08_comparison_ssm_parameters_var.tex"
)
with open(var_tex_path, "w") as tex_file:
    tex_file.write(parameters_tex)

# Export the parameter description for the efficiency-factor experiment.
text_parameters_repeat = f"""Again, we repeat this procedure $K={K_ef}$ times for varying levels of $n$ and use ${N_iter}$ iterations for all three methods, as well as ${N_samples}$ samples to estimate the optimal proposal."""
ef_tex_path = here(
    "chapters/03_state_space_models/03_08_comparison_ssm_parameters_ef.tex"
)
with open(ef_tex_path, "w") as tex_file:
    tex_file.write(text_parameters_repeat)
def _model(n, I0):
    """Build the simplified regional PGSSM with states (log I_t, log rho_{t+1}).

    Observations are Poisson with log-rate log I_t.  The innovation variance
    of log rho is 0.05 / n so that Var(log rho_{n+1}) accumulates to 0.05.
    """
    n_plus_1 = n + 1
    state_dim, obs_dim, noise_dim = 2, 1, 1
    # NOTE(review): for n == 1 the innovation variance falls back to 1 —
    # confirm this special case is intended.
    rho_var = 0.05 / n if n > 1 else 1

    # state equation
    u = jnp.zeros((n_plus_1, state_dim)).at[0, 0].set(jnp.log(I0))
    A = jnp.broadcast_to(
        jnp.array([[1.0, 1.0], [0.0, 1.0]]), (n, state_dim, state_dim)
    )
    # only the log-rho component receives innovations
    D = jnp.broadcast_to(jnp.eye(state_dim)[:, 1:2], (n, state_dim, noise_dim))
    Sigma0 = jnp.array([[1.0, 0.0], [0.0, 0.1]])
    Sigma = jnp.broadcast_to(rho_var * jnp.eye(1), (n, noise_dim, noise_dim))

    # observation equation: y_t | x_t ~ Poisson(exp(log I_t))
    B = jnp.broadcast_to(jnp.eye(state_dim)[:1], (n_plus_1, obs_dim, state_dim))
    v = jnp.zeros((n_plus_1, obs_dim))
    xi = jnp.empty((n_plus_1, obs_dim, 1))

    def dist(s, xi):
        return Poisson(log_rate=s)

    return PGSSM(u, A, D, Sigma0, Sigma, v, B, dist, xi)


def determine_efficiency_factor(n, key):
    """ESS percentages of the LA, MEIS and CEM proposals for one simulated y."""
    pgssm = _model(n, I0=1000)
    key, key_sim = jrn.split(key)
    _, (Y,) = simulate_pgssm(pgssm, 1, key_sim)

    # fit all three proposals
    key, key_meis, key_cem = jrn.split(key, 3)
    prop_la, _ = LA(Y, pgssm, N_iter)
    prop_meis, _ = MEIS(
        Y, pgssm, prop_la.z, prop_la.Omega, N_iter, N_samples, key_meis
    )
    prop_cem, lw_cem = CEM(pgssm, Y, N_samples, key_cem, N_iter)

    # evaluate LA and MEIS with N_ef fresh importance samples
    key, key_la, key_meis, key_cem = jrn.split(key, 4)
    _, lw_la = pgssm_importance_sampling(
        Y, pgssm, prop_la.z, prop_la.Omega, N_ef, key_la
    )
    _, lw_meis = pgssm_importance_sampling(
        Y, pgssm, prop_meis.z, prop_meis.Omega, N_ef, key_meis
    )
    # NOTE(review): the CEM ESS uses the N_samples log-weights returned while
    # fitting, not N_ef fresh ones; a fresh evaluation would be
    #   lw_cem = vmap(partial(log_weight_cem, y=Y, model=pgssm, proposal=prop_cem))(
    #       simulate_cem(prop_cem, N_samples, key_cem)
    #   )
    # — confirm the asymmetry with LA/MEIS is intended.
    return pd.Series(
        {
            "n": n,
            "N_samples": N_samples,
            "N_iter": N_iter,
            "EF_LA": ess_pct(lw_la),
            "EF_MEIS": ess_pct(lw_meis),
            "EF_CEM": ess_pct(lw_cem),
        }
    )


key = jrn.PRNGKey(140235293)
# Repeat the efficiency-factor experiment K_ef times per time-series length n,
# each with an independent PRNG key.
ns_ef = jnp.repeat(jnp.array([1, 10, 20, 50, 100]), K_ef)
key, *keys_ef = jrn.split(key, len(ns_ef) + 1)

results_list = [
    determine_efficiency_factor(n_t, key_t)
    for n_t, key_t in tqdm(zip(ns_ef, keys_ef), total=len(ns_ef))
]
results_ef = pd.DataFrame(results_list)

results_ef.to_csv(here("data/figures/ef_meis_cem_ssms.csv"), index=False)
def asymptotic_det_meis(Y, pgssm, prop_la, N_iter, N_samples, key, M: int):
    """Log-determinant of the N_samples-scaled covariance of M MEIS posterior modes."""
    _, *replication_keys = jrn.split(key, 1 + M)
    # one MEIS fit (seeded from the LA proposal) per replication key
    modes = jnp.array(
        [
            posterior_mode(
                MEIS(Y, pgssm, prop_la.z, prop_la.Omega, N_iter, N_samples, rk)[0]
            ).reshape(-1)
            for rk in replication_keys
        ]
    )
    scaled_cov = jnp.cov(modes, rowvar=False) * N_samples
    _, logdet = jnp.linalg.slogdet(scaled_cov)
    return logdet
def asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key, M: int):
    """Log-determinant of the N_samples-scaled covariance of M CEM proposal means."""
    _, *replication_keys = jrn.split(key, 1 + M)
    means = []
    for rk in replication_keys:
        proposal, _ = CEM(pgssm, Y, N_samples, rk, N_iter)
        means.append(proposal.mean[:, 0])
    modes = jnp.array(means)
    scaled_cov = jnp.cov(modes, rowvar=False) * N_samples
    _, logdet = jnp.linalg.slogdet(scaled_cov)
    return logdet
def asymptotic_variance(n: int, key: jrn.PRNGKey):
    """Compare the asymptotic variances of the CEM and MEIS posterior-mode
    estimators for one simulated observation sequence of length n.

    Returns a pd.Series with the two log-determinants and their
    asymptotic relative efficiency ARE = exp(log_DET_CEM - log_DET_MEIS).
    """
    pgssm = _model(n, I0=1000)
    key, subkey = jrn.split(key)
    _, (Y,) = simulate_pgssm(pgssm, 1, subkey)
    prop_la, _ = LA(Y, pgssm, N_iter)
    # BUG FIX: previously `key` was split into 1 + 2*M subkeys that were only
    # used for their length, and the *same* leading key was passed to both
    # asymptotic_det_cem and asymptotic_det_meis — so both methods' M
    # replications drew from an identical PRNG stream (correlated estimates).
    # Use independent subkeys for each method instead.
    key, key_cem, key_meis = jrn.split(key, 3)
    logdet_cem = asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key_cem, M=M)
    logdet_meis = asymptotic_det_meis(
        Y, pgssm, prop_la, N_iter, N_samples, key_meis, M=M
    )
    result = pd.Series(
        {
            "n": n,
            "N_samples": N_samples,
            "N_iter": N_iter,
            "log_DET_CEM": logdet_cem,
            "log_DET_MEIS": logdet_meis,
            "ARE": jnp.exp(logdet_cem - logdet_meis),
        }
    )
    return result


key = jrn.PRNGKey(140235293)
# Repeat the asymptotic-variance experiment K times per time-series length n.
ns_are = jnp.repeat(jnp.array([1, 2, 3, 4, 5]), K)
key, *keys_are = jrn.split(key, len(ns_are) + 1)

are_meis_cem_ssm_path = here("data/figures/are_meis_cem_ssms.csv")
if not are_meis_cem_ssm_path.exists():
    results_are = pd.DataFrame(
        [
            asymptotic_variance(n, k)
            for n, k in tqdm(zip(ns_are, keys_are), total=len(ns_are))
        ]
    )
    results_are.to_csv(are_meis_cem_ssm_path, index=False)
else:
    # BUG FIX: previously `results_are` was left undefined when the CSV
    # already existed, so the display expression below raised NameError.
    # Load the cached results instead.
    results_are = pd.read_csv(are_meis_cem_ssm_path)
results_are
let ej:f64[2,2] = transpose[permutation=(1, 0)] eh ek:f64[2,2] = add eh ej el:f64[2,2] = div ek 2.0:f64[] em:f64[2,2] = cholesky el en:i32[2,2] = iota[dimension=0 dtype=int32 shape=(2, 2) sharding=None] eo:i32[2,2] = add en 0:i32[] ep:i32[2,2] = iota[dimension=1 dtype=int32 shape=(2, 2) sharding=None] eq:bool[2,2] = ge eo ep er:f64[2,2] = broadcast_in_dim[ broadcast_dimensions=() shape=(2, 2) sharding=None ] 0.0:f64[] ei:f64[2,2] = select_n eq er em in (ei,) } ] eh es:key<fry>[] = random_wrap[impl=fry] eb et:f64[4000] = pjit[ name=_normal jaxpr={ lambda ; es:key<fry>[]. let et:f64[4000] = pjit[ name=_normal_real jaxpr={ lambda ; es:key<fry>[]. let eu:f64[4000] = pjit[ name=_uniform jaxpr={ lambda ; es:key<fry>[] ev:f64[] ew:f64[]. let ex:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] ev ey:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] ew ez:u64[4000] = random_bits[bit_width=64 shape=(4000,)] es fa:u64[4000] = shift_right_logical ez 12:u64[] fb:u64[4000] = or fa 4607182418800017408:u64[] fc:f64[4000] = bitcast_convert_type[new_dtype=float64] fb fd:f64[4000] = sub fc 1.0:f64[] fe:f64[1] = sub ey ex ff:f64[4000] = mul fd fe fg:f64[4000] = add ff ex eu:f64[4000] = max ex fg in (eu,) } ] es -0.9999999999999999:f64[] 1.0:f64[] fh:f64[4000] = erf_inv eu et:f64[4000] = mul 1.4142135623730951:f64[] fh in (et,) } ] es in (et,) } ] es fi:f64[4000] = mul et 1.0:f64[] fj:f64[4000] = add fi 0.0:f64[] fk:f64[4000] = mul fj 1.0:f64[] fl:f64[4000] = add fk 0.0:f64[] fm:f64[2000,2] = reshape[dimensions=None new_sizes=(2000, 2) sharding=None] fl fn:f64[1000,2,2] = reshape[ dimensions=None new_sizes=(1000, 2, 2) sharding=None ] fm fo:f64[1000,2,2,1] = broadcast_in_dim[ broadcast_dimensions=(0, 1, 2) shape=(1000, 2, 2, 1) sharding=None ] fn fp:f64[2,2] = pjit[ name=tril jaxpr={ lambda ; ei:f64[2,2]. 
let fq:i32[2,2] = iota[dimension=0 dtype=int32 shape=(2, 2) sharding=None] fr:i32[2,2] = add fq 0:i32[] fs:i32[2,2] = iota[dimension=1 dtype=int32 shape=(2, 2) sharding=None] ft:bool[2,2] = ge fr fs fu:f64[2,2] = broadcast_in_dim[ broadcast_dimensions=() shape=(2, 2) sharding=None ] 0.0:f64[] fp:f64[2,2] = select_n ft fu ei in (fp,) } ] ei fv:f64[2,1000,2,1] = dot_general[ dimension_numbers=(([1], [2]), ([], [])) preferred_element_type=float64 ] fp fo fw:f64[1000,2,2,1] = transpose[permutation=(1, 2, 0, 3)] fv fx:f64[1000,2,2] = squeeze[dimensions=(3,)] fw fy:f64[1,1,2] = broadcast_in_dim[ broadcast_dimensions=(2,) shape=(1, 1, 2) sharding=None ] ec fz:f64[1000,2,2] = add fx fy ga:f64[2,2,1000] = dot_general[ dimension_numbers=(([2], [2]), ([0], [1])) preferred_element_type=float64 ] ds fz gb:f64[1000,2,2] = transpose[permutation=(2, 0, 1)] ga gc:f64[2,1000,2] = transpose[permutation=(1, 0, 2)] gb gd:i64[2,2] = iota[dimension=0 dtype=int64 shape=(2, 2) sharding=None] ge:i64[2,2] = iota[dimension=1 dtype=int64 shape=(2, 2) sharding=None] gf:i64[2,2] = add gd 0:i64[] gg:bool[2,2] = eq gf ge gh:f64[2,2] = convert_element_type[new_dtype=float64 weak_type=False] gg gi:f64[1,2,2] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 2, 2) sharding=None ] gh gj:f64[2,2,2] = concatenate[dimension=0] gi dt gk:i64[2,2] = iota[dimension=0 dtype=int64 shape=(2, 2) sharding=None] gl:i64[2,2] = iota[dimension=1 dtype=int64 shape=(2, 2) sharding=None] gm:i64[2,2] = add gk 0:i64[] gn:bool[2,2] = eq gm gl go:f64[2,2] = convert_element_type[new_dtype=float64 weak_type=False] gn gp:f64[1,2,2] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 2, 2) sharding=None ] go gq:f64[2,2,2] = concatenate[dimension=0] gp du gr:f64[1000,2] = broadcast_in_dim[ broadcast_dimensions=() shape=(1000, 2) sharding=None ] 0.0:f64[] _:f64[1000,2] gs:f64[2,1000,2] = scan[ _split_transpose=False jaxpr={ lambda ; gt:f64[1000,2] gu:f64[1000,2] gv:f64[2,2] gw:f64[2,2]. 
let gx:f64[1000,2] = pjit[ name=_solve_triangular jaxpr={ lambda ; gv:f64[2,2] gt:f64[1000,2]. let gy:f64[1000,2,1] = broadcast_in_dim[ broadcast_dimensions=(0, 1) shape=(1000, 2, 1) sharding=None ] gt gz:f64[2,1,1000] = transpose[permutation=(1, 2, 0)] gy ha:f64[2,1000] = reshape[ dimensions=None new_sizes=(2, 1000) sharding=None ] gz hb:f64[2,1000] = triangular_solve[ conjugate_a=False left_side=True lower=True transpose_a=False unit_diagonal=False ] gv ha hc:f64[2,1,1000] = reshape[ dimensions=None new_sizes=(2, 1, 1000) sharding=None ] hb hd:f64[2,1,1000] = slice[ limit_indices=(2, 1, 1000) start_indices=(0, 0, 0) strides=None ] hc he:f64[1000,2,1] = transpose[permutation=(2, 0, 1)] hd gx:f64[1000,2] = squeeze[dimensions=(2,)] he in (gx,) } ] gv gt hf:f64[2,1000] = dot_general[ dimension_numbers=(([1], [1]), ([], [])) preferred_element_type=float64 ] gw gx hg:f64[1000,2] = transpose[permutation=(1, 0)] hf hh:f64[1000,2] = add hg gu in (hh, hh) } length=2 linear=(False, False, False, False) num_carry=1 num_consts=0 reverse=False unroll=1 ] gr gc gj gq hi:f64[1000,2,2] = transpose[permutation=(1, 0, 2)] gs hj:f64[1,2,2] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 2, 2) sharding=None ] dr hk:f64[1000,2,2] = add hi hj hl:f64[1000,4] = reshape[dimensions=None new_sizes=(1000, 4) sharding=None] fz hm:f64[1,2,2] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 2, 2) sharding=None ] dr hn:f64[1,2,2] = mul 2.0:f64[] hm ho:f64[1000,2,2] = sub hn hk hp:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 1.0:f64[] hq:f64[1] = mul 4.0:f64[] hp hr:f64[1000] = pjit[name=norm jaxpr=norm] hl hs:f64[1000] = integer_pow[y=2] hr ht:f64[1] = mul 0.5:f64[] hq hu:f64[1000] = mul 0.5:f64[] hs hv:f64[1000] = igamma ht hu hw:f64[1000] = sub 1.0:f64[] hv hx:f64[1] = mul 0.5:f64[] hq hy:f64[1000] = custom_jvp_call[ name=_igammainv_custom_gradient call_jaxpr={ lambda ; hz:f64[4] ia:f64[5] ib:f64[1] ic:f64[1000]. 
let id:f64[1000] = sub 1.0:f64[] ic ie:f64[1] = lgamma ib if:f64[1000] = neg ic ig:f64[1000] = log1p if ih:f64[1000] = add ig ie ii:f64[1000] = neg ih ij:f64[1] = sub ib 1.0:f64[] ik:f64[1000] = custom_jvp_call[ name=xlogy call_jaxpr={ lambda ; il:f64[1] im:f64[1000]. let in:bool[1] = ne il 0.0:f64[] io:f64[1000] = log im ip:f64[1000] = mul il io iq:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 0.0:f64[] ir:f64[1000] = pjit[name=_where jaxpr=_where3] in ip iq in (ir,) } jvp=_xlogy_jvp symbolic_zeros=False ] ij ii is:f64[1000] = square ik it:f64[1000] = mul is ik iu:f64[1000] = square is iv:f64[1] = square ib iw:f64[1] = mul iv ib ix:f64[1] = sub ib 1.0:f64[] iy:f64[1000] = add 1.0:f64[] ik iz:f64[1000] = mul ix iy ja:f64[1] = sub ib 1.0:f64[] jb:f64[1] = mul 3.0:f64[] ib jc:f64[1] = sub jb 5.0:f64[] jd:f64[1] = div jc 2.0:f64[] je:f64[1] = sub ib 2.0:f64[] jf:f64[1000] = div ik 2.0:f64[] jg:f64[1000] = sub je jf jh:f64[1000] = mul ik jg ji:f64[1000] = add jd jh jj:f64[1000] = mul ja ji jk:f64[1] = sub ib 1.0:f64[] jl:f64[1000] = div it 3.0:f64[] jm:f64[1] = mul 3.0:f64[] ib jn:f64[1] = sub jm 5.0:f64[] jo:f64[1000] = mul jn is jp:f64[1000] = div jo 2.0:f64[] jq:f64[1000] = sub jl jp jr:f64[1] = mul 6.0:f64[] ib js:f64[1] = sub iv jr jt:f64[1] = add js 7.0:f64[] ju:f64[1000] = mul jt ik jv:f64[1000] = add jq ju jw:f64[1] = mul 11.0:f64[] iv jx:f64[1] = mul 46.0:f64[] ib jy:f64[1] = sub jw jx jz:f64[1] = add jy 47.0:f64[] ka:f64[1] = div jz 6.0:f64[] kb:f64[1000] = add jv ka kc:f64[1000] = mul jk kb kd:f64[1] = sub ib 1.0:f64[] ke:f64[1000] = neg iu kf:f64[1000] = div ke 4.0:f64[] kg:f64[1] = mul 11.0:f64[] ib kh:f64[1] = sub kg 17.0:f64[] ki:f64[1000] = mul kh it kj:f64[1000] = div ki 6.0:f64[] kk:f64[1000] = add kf kj kl:f64[1] = mul -3.0:f64[] iv km:f64[1] = mul 13.0:f64[] ib kn:f64[1] = add kl km ko:f64[1] = sub kn 13.0:f64[] kp:f64[1000] = mul ko is kq:f64[1000] = add kk kp kr:f64[1] = mul 2.0:f64[] iw ks:f64[1] = mul 25.0:f64[] 
iv kt:f64[1] = sub kr ks ku:f64[1] = mul 72.0:f64[] ib kv:f64[1] = add kt ku kw:f64[1] = sub kv 61.0:f64[] kx:f64[1000] = mul kw ik ky:f64[1000] = div kx 2.0:f64[] kz:f64[1000] = add kq ky la:f64[1] = mul 25.0:f64[] iw lb:f64[1] = mul 195.0:f64[] iv lc:f64[1] = sub la lb ld:f64[1] = mul 477.0:f64[] ib le:f64[1] = add lc ld lf:f64[1] = sub le 379.0:f64[] lg:f64[1] = div lf 12.0:f64[] lh:f64[1000] = add kz lg li:f64[1000] = mul kd lh lj:f64[1000] = add ii ik lk:f64[1000] = div li ii ll:f64[1000] = add lk kc lm:f64[1000] = div ll ii ln:f64[1000] = div jj ii lo:f64[1000] = add lm ln lp:f64[1000] = add lo iz lq:f64[1000] = div lp ii lr:f64[1000] = add lj lq ls:f64[1000] = neg ih lt:f64[1] = sub 1.0:f64[] ib lu:f64[1000] = neg ih lv:f64[1000] = log lu lw:f64[1000] = mul lt lv lx:f64[1000] = sub ls lw ly:f64[1000] = square lx lz:bool[1000] = gt ih -4.605170185988091:f64[] ma:f64[1000] = neg ih mb:f64[1] = sub 1.0:f64[] ib mc:f64[1000] = log lx md:f64[1000] = mul mb mc me:f64[1000] = sub ma md mf:f64[1] = sub 3.0:f64[] ib mg:f64[1] = mul 2.0:f64[] mf mh:f64[1000] = mul mg lx mi:f64[1000] = add ly mh mj:f64[1] = sub 2.0:f64[] ib mk:f64[1] = sub 3.0:f64[] ib ml:f64[1] = mul mj mk mm:f64[1000] = add mi ml mn:f64[1] = sub 5.0:f64[] ib mo:f64[1000] = mul mn lx mp:f64[1000] = add ly mo mq:f64[1000] = add mp 2.0:f64[] mr:f64[1000] = div mm mq ms:f64[1000] = log mr mt:f64[1000] = sub me ms mu:f64[1000] = pjit[name=_where jaxpr=_where] lz mt lr mv:bool[1000] = ge ih -1.8971199848858813:f64[] mw:f64[1000] = neg ih mx:f64[1] = sub ib 1.0:f64[] my:f64[1000] = custom_jvp_call[ name=xlogy call_jaxpr={ lambda ; mz:f64[1] na:f64[1000]. 
let nb:bool[1] = ne mz 0.0:f64[] nc:f64[1000] = log na nd:f64[1000] = mul mz nc ne:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 0.0:f64[] nf:f64[1000] = pjit[name=_where jaxpr=_where3] nb nd ne in (nf,) } jvp=_xlogy_jvp symbolic_zeros=False ] mx lx ng:f64[1000] = add mw my nh:f64[1] = sub 1.0:f64[] ib ni:f64[1000] = add 1.0:f64[] lx nj:f64[1000] = div nh ni nk:f64[1000] = log1p nj nl:f64[1000] = sub ng nk nm:f64[1000] = pjit[name=_where jaxpr=_where] mv nl mu nn:f64[1000] = exp ih no:f64[1000] = sub -0.5772156649015329:f64[] nn np:f64[1000] = exp no nq:f64[1000] = exp np nr:f64[1000] = mul np nq ns:bool[1] = lt ib 0.3:f64[] nt:bool[1000] = ge ih -1.0498221244986778:f64[] nu:bool[1000] = and ns nt nv:f64[1000] = exp nr nw:f64[1000] = mul np nv nx:f64[1000] = pjit[name=_where jaxpr=_where] nu nw nm ny:f64[1000] = exp ih nz:f64[1000] = mul ny id oa:bool[1000] = gt nz 1e-08:f64[] ob:bool[1000] = gt id 1e-05:f64[] oc:bool[1000] = and oa ob od:f64[1] = exp ie oe:f64[1000] = mul ic od of:f64[1000] = mul oe ib og:f64[1] = integer_pow[y=-1] ib oh:f64[1000] = pow of og oi:f64[1000] = neg id oj:f64[1000] = div oi ib ok:f64[1000] = sub oj 0.5772156649015329:f64[] ol:f64[1000] = exp ok om:f64[1000] = pjit[name=_where jaxpr=_where] oc oh ol on:bool[1000] = gt ih -0.5108256237659907:f64[] oo:bool[1000] = ge ih -0.7985076962177716:f64[] op:bool[1] = ge ib 0.3:f64[] oq:bool[1000] = and oo op or:bool[1000] = or on oq os:f64[1] = add ib 1.0:f64[] ot:f64[1000] = div om os ou:f64[1000] = sub 1.0:f64[] ot ov:f64[1000] = div om ou ow:f64[1000] = pjit[name=_where jaxpr=_where] or ov nx ox:f64[1] = sqrt ib oy:bool[1000] = lt ic 0.5:f64[] oz:f64[1000] = log ic pa:f64[1000] = mul -2.0:f64[] oz pb:f64[1000] = sqrt pa pc:f64[1000] = log id pd:f64[1000] = mul -2.0:f64[] pc pe:f64[1000] = sqrt pd pf:f64[1000] = pjit[name=_where jaxpr=_where] oy pb pe pg:f64[1000] = pjit[name=polyval jaxpr=polyval] hz pf ph:f64[1000] = pjit[name=polyval jaxpr=polyval1] ia pf 
pi:f64[1000] = div pg ph pj:f64[1000] = sub pf pi pk:bool[1000] = lt ic 0.5:f64[] pl:f64[1000] = neg pj pm:f64[1000] = pjit[name=_where jaxpr=_where] pk pl pj pn:f64[1000] = square pm po:f64[1000] = mul pn pm pp:f64[1000] = square pn pq:f64[1000] = mul pp pm pr:f64[1000] = mul pm ox ps:f64[1000] = add ib pr pt:f64[1000] = sub pn 1.0:f64[] pu:f64[1000] = div pt 3.0:f64[] pv:f64[1000] = add ps pu pw:f64[1000] = mul 7.0:f64[] pm px:f64[1000] = sub po pw py:f64[1] = mul 36.0:f64[] ox pz:f64[1000] = div px py qa:f64[1000] = add pv pz qb:f64[1000] = mul 3.0:f64[] pp qc:f64[1000] = mul 7.0:f64[] pn qd:f64[1000] = add qb qc qe:f64[1000] = sub qd 16.0:f64[] qf:f64[1] = mul 810.0:f64[] ib qg:f64[1000] = div qe qf qh:f64[1000] = sub qa qg qi:f64[1000] = mul 9.0:f64[] pq qj:f64[1000] = mul 256.0:f64[] po qk:f64[1000] = add qi qj ql:f64[1000] = mul 433.0:f64[] pm qm:f64[1000] = sub qk ql qn:f64[1] = mul 38880.0:f64[] ib qo:f64[1] = mul qn ox qp:f64[1000] = div qm qo qq:f64[1000] = add qh qp qr:f64[1] = sub ib 1.0:f64[] qs:f64[1] = mul ib qr qt:f64[1] = copy qs qu:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 2.0:f64[] qv:f64[1] = custom_jvp_call[ name=_maximum_ call_jaxpr={ lambda ; qw:f64[1] qx:f64[1]. let qy:f64[1] = max qw qx in (qy,) } jvp=_maximum_jvp symbolic_zeros=False ] qu qt qz:f64[1] = neg qv ra:f64[1] = mul qz 2.302585092994046:f64[] rb:bool[1000] = le ih ra rc:f64[1000] = neg ih rd:f64[1] = sub ib 1.0:f64[] re:f64[1000] = custom_jvp_call[ name=xlogy call_jaxpr={ lambda ; rf:f64[1] rg:f64[1000]. 
let rh:bool[1] = ne rf 0.0:f64[] ri:f64[1000] = log rg rj:f64[1000] = mul rf ri rk:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 0.0:f64[] rl:f64[1000] = pjit[name=_where jaxpr=_where3] rh rj rk in (rl,) } jvp=_xlogy_jvp symbolic_zeros=False ] rd rc rm:f64[1000] = square re rn:f64[1000] = mul rm re ro:f64[1000] = square rm rp:f64[1] = square ib rq:f64[1] = mul rp ib rr:f64[1] = sub ib 1.0:f64[] rs:f64[1000] = add 1.0:f64[] re rt:f64[1000] = mul rr rs ru:f64[1] = sub ib 1.0:f64[] rv:f64[1] = mul 3.0:f64[] ib rw:f64[1] = sub rv 5.0:f64[] rx:f64[1] = div rw 2.0:f64[] ry:f64[1] = sub ib 2.0:f64[] rz:f64[1000] = div re 2.0:f64[] sa:f64[1000] = sub ry rz sb:f64[1000] = mul re sa sc:f64[1000] = add rx sb sd:f64[1000] = mul ru sc se:f64[1] = sub ib 1.0:f64[] sf:f64[1000] = div rn 3.0:f64[] sg:f64[1] = mul 3.0:f64[] ib sh:f64[1] = sub sg 5.0:f64[] si:f64[1000] = mul sh rm sj:f64[1000] = div si 2.0:f64[] sk:f64[1000] = sub sf sj sl:f64[1] = mul 6.0:f64[] ib sm:f64[1] = sub rp sl sn:f64[1] = add sm 7.0:f64[] so:f64[1000] = mul sn re sp:f64[1000] = add sk so sq:f64[1] = mul 11.0:f64[] rp sr:f64[1] = mul 46.0:f64[] ib ss:f64[1] = sub sq sr st:f64[1] = add ss 47.0:f64[] su:f64[1] = div st 6.0:f64[] sv:f64[1000] = add sp su sw:f64[1000] = mul se sv sx:f64[1] = sub ib 1.0:f64[] sy:f64[1000] = neg ro sz:f64[1000] = div sy 4.0:f64[] ta:f64[1] = mul 11.0:f64[] ib tb:f64[1] = sub ta 17.0:f64[] tc:f64[1000] = mul tb rn td:f64[1000] = div tc 6.0:f64[] te:f64[1000] = add sz td tf:f64[1] = mul -3.0:f64[] rp tg:f64[1] = mul 13.0:f64[] ib th:f64[1] = add tf tg ti:f64[1] = sub th 13.0:f64[] tj:f64[1000] = mul ti rm tk:f64[1000] = add te tj tl:f64[1] = mul 2.0:f64[] rq tm:f64[1] = mul 25.0:f64[] rp tn:f64[1] = sub tl tm to:f64[1] = mul 72.0:f64[] ib tp:f64[1] = add tn to tq:f64[1] = sub tp 61.0:f64[] tr:f64[1000] = mul tq re ts:f64[1000] = div tr 2.0:f64[] tt:f64[1000] = add tk ts tu:f64[1] = mul 25.0:f64[] rq tv:f64[1] = mul 195.0:f64[] rp tw:f64[1] = sub tu 
tv tx:f64[1] = mul 477.0:f64[] ib ty:f64[1] = add tw tx tz:f64[1] = sub ty 379.0:f64[] ua:f64[1] = div tz 12.0:f64[] ub:f64[1000] = add tt ua uc:f64[1000] = mul sx ub ud:f64[1000] = add rc re ue:f64[1000] = div uc rc uf:f64[1000] = add ue sw ug:f64[1000] = div uf rc uh:f64[1000] = div sd rc ui:f64[1000] = add ug uh uj:f64[1000] = add ui rt uk:f64[1000] = div uj rc ul:f64[1000] = add ud uk um:f64[1000] = neg ih un:f64[1] = sub ib 1.0:f64[] uo:f64[1000] = custom_jvp_call[ name=xlogy call_jaxpr={ lambda ; up:f64[1] uq:f64[1000]. let ur:bool[1] = ne up 0.0:f64[] us:f64[1000] = log uq ut:f64[1000] = mul up us uu:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 0.0:f64[] uv:f64[1000] = pjit[name=_where jaxpr=_where3] ur ut uu in (uv,) } jvp=_xlogy_jvp symbolic_zeros=False ] un qq uw:f64[1000] = add um uo ux:f64[1] = sub 1.0:f64[] ib uy:f64[1000] = add 1.0:f64[] qq uz:f64[1000] = div ux uy va:f64[1000] = log1p uz vb:f64[1000] = sub uw va vc:f64[1000] = neg ih vd:f64[1] = sub ib 1.0:f64[] ve:f64[1000] = custom_jvp_call[ name=xlogy call_jaxpr={ lambda ; vf:f64[1] vg:f64[1000]. 
let vh:bool[1] = ne vf 0.0:f64[] vi:f64[1000] = log vg vj:f64[1000] = mul vf vi vk:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 0.0:f64[] vl:f64[1000] = pjit[name=_where jaxpr=_where3] vh vj vk in (vl,) } jvp=_xlogy_jvp symbolic_zeros=False ] vd vb vm:f64[1000] = add vc ve vn:f64[1] = sub 1.0:f64[] ib vo:f64[1000] = add 1.0:f64[] vb vp:f64[1000] = div vn vo vq:f64[1000] = log1p vp vr:f64[1000] = sub vm vq vs:f64[1000] = pjit[name=_where jaxpr=_where] rb ul vr vt:f64[1] = mul 3.0:f64[] ib vu:bool[1000] = lt qq vt vv:f64[1000] = pjit[name=_where jaxpr=_where] vu qq vs vw:bool[1] = ge ib 500.0:f64[] vx:f64[1000] = div qq ib vy:f64[1000] = sub 1.0:f64[] vx vz:f64[1000] = abs vy wa:bool[1000] = lt vz 1e-06:f64[] wb:bool[1000] = and vw wa wc:f64[1000] = pjit[name=_where jaxpr=_where] wb qq vv wd:f64[1000] = log ic we:f64[1] = add ib 1.0:f64[] wf:f64[1] = lgamma we wg:f64[1000] = add wd wf wh:f64[1000] = add wg qq wi:f64[1000] = div wh ib wj:f64[1000] = exp wi wk:f64[1] = add ib 1.0:f64[] wl:f64[1000] = div wj wk wm:f64[1] = add ib 2.0:f64[] wn:f64[1000] = div wj wm wo:f64[1000] = add 1.0:f64[] wn wp:f64[1000] = mul wl wo wq:f64[1000] = log1p wp wr:f64[1000] = add wg wj ws:f64[1000] = sub wr wq wt:f64[1000] = div ws ib wu:f64[1000] = exp wt wv:f64[1] = add ib 1.0:f64[] ww:f64[1000] = div wu wv wx:f64[1] = add ib 2.0:f64[] wy:f64[1000] = div wu wx wz:f64[1000] = add 1.0:f64[] wy xa:f64[1000] = mul ww wz xb:f64[1000] = log1p xa xc:f64[1000] = add wg wu xd:f64[1000] = sub xc xb xe:f64[1000] = div xd ib xf:f64[1000] = exp xe xg:f64[1] = add ib 1.0:f64[] xh:f64[1000] = div xf xg xi:f64[1] = add ib 2.0:f64[] xj:f64[1000] = div xf xi xk:f64[1] = add ib 3.0:f64[] xl:f64[1000] = div xf xk xm:f64[1000] = add 1.0:f64[] xl xn:f64[1000] = mul xj xm xo:f64[1000] = add 1.0:f64[] xn xp:f64[1000] = mul xh xo xq:f64[1000] = log1p xp xr:f64[1000] = add wg xf xs:f64[1000] = sub xr xq xt:f64[1000] = div xs ib xu:f64[1000] = exp xt xv:f64[1] = add ib 1.0:f64[] 
xw:f64[1] = mul 0.15:f64[] xv xx:bool[1000] = le qq xw xy:f64[1000] = pjit[name=_where jaxpr=_where] xx xu qq xz:bool[1000] = broadcast_in_dim[ broadcast_dimensions=() shape=(1000,) sharding=None ] False:bool[] ya:f64[1000] = broadcast_in_dim[ broadcast_dimensions=() shape=(1000,) sharding=None ] 1.0:f64[] yb:f64[1000] = broadcast_in_dim[ broadcast_dimensions=() shape=(1000,) sharding=None ] 1.0:f64[] _:bool[1000] _:f64[] _:f64[1000] yc:f64[1000] = while[ body_jaxpr={ lambda ; yd:f64[1000] ye:f64[1] yf:bool[1000] yg:f64[] yh:f64[1000] yi:f64[1000]. let yj:f64[1000] = mul yh yd yk:f64[1] = add ye yg yl:f64[1000] = div yj yk ym:f64[1000] = add yi yl yn:f64[1000] = pjit[name=_where jaxpr=_where4] yf yi ym yo:bool[1000] = lt yl 0.0001:f64[] yp:bool[] = gt yg 100.0:f64[] yq:bool[1000] = or yo yp yr:f64[] = add yg 1.0:f64[] in (yq, yr, yl, yn) } body_nconsts=2 cond_jaxpr={ lambda ; ys:bool[1000] yt:f64[] yu:f64[1000] yv:f64[1000]. let yw:bool[1000] = not ys yx:bool[] = reduce_or[axes=(0,)] yw in (yx,) } cond_nconsts=0 ] xy ib xz 1.0:f64[] ya yb yy:f64[1000] = log yc yz:f64[1000] = add wg xy za:f64[1000] = sub yz yy zb:f64[1000] = div za ib zc:f64[1000] = exp zb zd:f64[1] = add ib 1.0:f64[] ze:f64[1] = mul 0.01:f64[] zd zf:bool[1000] = le xy ze zg:f64[1] = add ib 1.0:f64[] zh:f64[1] = mul 0.7:f64[] zg zi:bool[1000] = gt xy zh zj:bool[1000] = or zf zi zk:f64[1000] = log zc zl:f64[1000] = mul ib zk zm:f64[1000] = sub zl zc zn:f64[1000] = sub zm wg zo:f64[1000] = add zn yy zp:f64[1000] = sub ib zc zq:f64[1000] = div zo zp zr:f64[1000] = sub 1.0:f64[] zq zs:f64[1000] = mul zc zr zt:f64[1000] = pjit[name=_where jaxpr=_where] zj xy zs zu:bool[1000] = le ic 0.5:f64[] zv:f64[1000] = pjit[name=_where jaxpr=_where] zu zt wc zw:bool[1] = lt ib 1.0:f64[] zx:f64[1000] = pjit[name=_where jaxpr=_where1] zw ow zv zy:bool[1] = eq ib 1.0:f64[] zz:f64[1000] = neg ig baa:f64[1000] = pjit[name=_where jaxpr=_where1] zy zz zx bab:f64[1000] = log baa bac:f64[1000] = mul ib bab bad:f64[1000] = 
sub bac baa bae:f64[1] = lgamma ib baf:f64[1000] = sub bad bae bag:f64[1000] = exp baf bah:bool[1000] = le ic 0.9:f64[] bai:bool[1000] = and bah True:bool[] baj:bool[1000] = gt id 0.9:f64[] bak:bool[1000] = and baj False:bool[] bal:bool[1000] = or bai bak bam:f64[1000] = igamma ib baa ban:f64[1000] = sub bam ic bao:f64[1000] = mul ban baa bap:f64[1000] = div bao bag baq:f64[1000] = igammac ib baa bar:f64[1000] = sub baq id bas:f64[1000] = neg bar bat:f64[1000] = mul bas baa bau:f64[1000] = div bat bag bav:f64[1000] = pjit[name=_where jaxpr=_where] bal bap bau baw:f64[1] = sub ib 1.0:f64[] bax:f64[1000] = div baw baa bay:f64[1000] = add -1.0:f64[] bax baz:bool[1000] = pjit[name=isinf jaxpr=isinf] bay bba:f64[1000] = sub baa bav bbb:f64[1000] = mul 0.5:f64[] bav bbc:f64[1000] = mul bbb bay bbd:f64[1000] = sub 1.0:f64[] bbc bbe:f64[1000] = div bav bbd bbf:f64[1000] = sub baa bbe bbg:f64[1000] = pjit[name=_where jaxpr=_where] baz bba bbf bbh:bool[1000] = eq bag 0.0:f64[] bbi:f64[1000] = pjit[name=_where jaxpr=_where] bbh baa bbg bbj:f64[1000] = log bbi bbk:f64[1000] = mul ib bbj bbl:f64[1000] = sub bbk bbi bbm:f64[1] = lgamma ib bbn:f64[1000] = sub bbl bbm bbo:f64[1000] = exp bbn bbp:bool[1000] = le ic 0.9:f64[] bbq:bool[1000] = and bbp True:bool[] bbr:bool[1000] = gt id 0.9:f64[] bbs:bool[1000] = and bbr False:bool[] bbt:bool[1000] = or bbq bbs bbu:f64[1000] = igamma ib bbi bbv:f64[1000] = sub bbu ic bbw:f64[1000] = mul bbv bbi bbx:f64[1000] = div bbw bbo bby:f64[1000] = igammac ib bbi bbz:f64[1000] = sub bby id bca:f64[1000] = neg bbz bcb:f64[1000] = mul bca bbi bcc:f64[1000] = div bcb bbo bcd:f64[1000] = pjit[name=_where jaxpr=_where] bbt bbx bcc bce:f64[1] = sub ib 1.0:f64[] bcf:f64[1000] = div bce bbi bcg:f64[1000] = add -1.0:f64[] bcf bch:bool[1000] = pjit[name=isinf jaxpr=isinf] bcg bci:f64[1000] = sub bbi bcd bcj:f64[1000] = mul 0.5:f64[] bcd bck:f64[1000] = mul bcj bcg bcl:f64[1000] = sub 1.0:f64[] bck bcm:f64[1000] = div bcd bcl bcn:f64[1000] = sub bbi bcm 
bco:f64[1000] = pjit[name=_where jaxpr=_where] bch bci bcn bcp:bool[1000] = eq bbo 0.0:f64[] bcq:f64[1000] = pjit[name=_where jaxpr=_where] bcp bbi bco bcr:f64[1000] = log bcq bcs:f64[1000] = mul ib bcr bct:f64[1000] = sub bcs bcq bcu:f64[1] = lgamma ib bcv:f64[1000] = sub bct bcu bcw:f64[1000] = exp bcv bcx:bool[1000] = le ic 0.9:f64[] bcy:bool[1000] = and bcx True:bool[] bcz:bool[1000] = gt id 0.9:f64[] bda:bool[1000] = and bcz False:bool[] bdb:bool[1000] = or bcy bda bdc:f64[1000] = igamma ib bcq bdd:f64[1000] = sub bdc ic bde:f64[1000] = mul bdd bcq bdf:f64[1000] = div bde bcw bdg:f64[1000] = igammac ib bcq bdh:f64[1000] = sub bdg id bdi:f64[1000] = neg bdh bdj:f64[1000] = mul bdi bcq bdk:f64[1000] = div bdj bcw bdl:f64[1000] = pjit[name=_where jaxpr=_where] bdb bdf bdk bdm:f64[1] = sub ib 1.0:f64[] bdn:f64[1000] = div bdm bcq bdo:f64[1000] = add -1.0:f64[] bdn bdp:bool[1000] = pjit[name=isinf jaxpr=isinf] bdo bdq:f64[1000] = sub bcq bdl bdr:f64[1000] = mul 0.5:f64[] bdl bds:f64[1000] = mul bdr bdo bdt:f64[1000] = sub 1.0:f64[] bds bdu:f64[1000] = div bdl bdt bdv:f64[1000] = sub bcq bdu bdw:f64[1000] = pjit[name=_where jaxpr=_where] bdp bdq bdv bdx:bool[1000] = eq bcw 0.0:f64[] bdy:f64[1000] = pjit[name=_where jaxpr=_where] bdx bcq bdw bdz:bool[1] = lt ib 0.0:f64[] bea:bool[1000] = lt ic 0.0:f64[] beb:bool[1000] = or bdz bea bec:bool[1000] = gt ic 1.0:f64[] bed:bool[1000] = or beb bec bee:f64[1000] = pjit[name=_where jaxpr=_where2] bed nan:f64[] bdy bef:bool[1000] = eq ic 0.0:f64[] beg:f64[1000] = pjit[name=_where jaxpr=_where2] bef 0.0:f64[] bee beh:bool[1000] = eq ic 1.0:f64[] bei:f64[1000] = pjit[name=_where jaxpr=_where2] beh inf:f64[] beg in (bei,) } jvp=_igammainv_jvp num_consts=2 symbolic_zeros=False ] de df hx hw bej:f64[1000] = mul 2.0:f64[] hy bek:f64[1,2,2] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 2, 2) sharding=None ] dr bel:f64[1000] = div bej hs bem:f64[1000] = sqrt bel ben:f64[1000,1,1] = broadcast_in_dim[ 
broadcast_dimensions=(0,) shape=(1000, 1, 1) sharding=None ] bem beo:f64[1,2,2] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 2, 2) sharding=None ] dr bep:f64[1000,2,2] = sub hk beo beq:f64[1000,2,2] = mul ben bep ber:f64[1000,2,2] = add bek beq bes:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 1.0:f64[] bet:f64[1] = mul 4.0:f64[] bes beu:f64[1000] = pjit[name=norm jaxpr=norm] hl bev:f64[1000] = integer_pow[y=2] beu bew:f64[1] = mul 0.5:f64[] bet bex:f64[1000] = mul 0.5:f64[] bev bey:f64[1000] = igamma bew bex bez:f64[1000] = sub 1.0:f64[] bey bfa:f64[1] = mul 0.5:f64[] bet bfb:f64[1000] = custom_jvp_call[ name=_igammainv_custom_gradient call_jaxpr={ lambda ; bfc:f64[4] bfd:f64[5] bfe:f64[1] bff:f64[1000]. let bfg:f64[1000] = sub 1.0:f64[] bff bfh:f64[1] = lgamma bfe bfi:f64[1000] = neg bff bfj:f64[1000] = log1p bfi bfk:f64[1000] = add bfj bfh bfl:f64[1000] = neg bfk bfm:f64[1] = sub bfe 1.0:f64[] bfn:f64[1000] = custom_jvp_call[ name=xlogy call_jaxpr={ lambda ; bfo:f64[1] bfp:f64[1000]. 
let bfq:bool[1] = ne bfo 0.0:f64[] bfr:f64[1000] = log bfp bfs:f64[1000] = mul bfo bfr bft:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 0.0:f64[] bfu:f64[1000] = pjit[name=_where jaxpr=_where3] bfq bfs bft in (bfu,) } jvp=_xlogy_jvp symbolic_zeros=False ] bfm bfl bfv:f64[1000] = square bfn bfw:f64[1000] = mul bfv bfn bfx:f64[1000] = square bfv bfy:f64[1] = square bfe bfz:f64[1] = mul bfy bfe bga:f64[1] = sub bfe 1.0:f64[] bgb:f64[1000] = add 1.0:f64[] bfn bgc:f64[1000] = mul bga bgb bgd:f64[1] = sub bfe 1.0:f64[] bge:f64[1] = mul 3.0:f64[] bfe bgf:f64[1] = sub bge 5.0:f64[] bgg:f64[1] = div bgf 2.0:f64[] bgh:f64[1] = sub bfe 2.0:f64[] bgi:f64[1000] = div bfn 2.0:f64[] bgj:f64[1000] = sub bgh bgi bgk:f64[1000] = mul bfn bgj bgl:f64[1000] = add bgg bgk bgm:f64[1000] = mul bgd bgl bgn:f64[1] = sub bfe 1.0:f64[] bgo:f64[1000] = div bfw 3.0:f64[] bgp:f64[1] = mul 3.0:f64[] bfe bgq:f64[1] = sub bgp 5.0:f64[] bgr:f64[1000] = mul bgq bfv bgs:f64[1000] = div bgr 2.0:f64[] bgt:f64[1000] = sub bgo bgs bgu:f64[1] = mul 6.0:f64[] bfe bgv:f64[1] = sub bfy bgu bgw:f64[1] = add bgv 7.0:f64[] bgx:f64[1000] = mul bgw bfn bgy:f64[1000] = add bgt bgx bgz:f64[1] = mul 11.0:f64[] bfy bha:f64[1] = mul 46.0:f64[] bfe bhb:f64[1] = sub bgz bha bhc:f64[1] = add bhb 47.0:f64[] bhd:f64[1] = div bhc 6.0:f64[] bhe:f64[1000] = add bgy bhd bhf:f64[1000] = mul bgn bhe bhg:f64[1] = sub bfe 1.0:f64[] bhh:f64[1000] = neg bfx bhi:f64[1000] = div bhh 4.0:f64[] bhj:f64[1] = mul 11.0:f64[] bfe bhk:f64[1] = sub bhj 17.0:f64[] bhl:f64[1000] = mul bhk bfw bhm:f64[1000] = div bhl 6.0:f64[] bhn:f64[1000] = add bhi bhm bho:f64[1] = mul -3.0:f64[] bfy bhp:f64[1] = mul 13.0:f64[] bfe bhq:f64[1] = add bho bhp bhr:f64[1] = sub bhq 13.0:f64[] bhs:f64[1000] = mul bhr bfv bht:f64[1000] = add bhn bhs bhu:f64[1] = mul 2.0:f64[] bfz bhv:f64[1] = mul 25.0:f64[] bfy bhw:f64[1] = sub bhu bhv bhx:f64[1] = mul 72.0:f64[] bfe bhy:f64[1] = add bhw bhx bhz:f64[1] = sub bhy 61.0:f64[] 
bia:f64[1000] = mul bhz bfn bib:f64[1000] = div bia 2.0:f64[] bic:f64[1000] = add bht bib bid:f64[1] = mul 25.0:f64[] bfz bie:f64[1] = mul 195.0:f64[] bfy bif:f64[1] = sub bid bie big:f64[1] = mul 477.0:f64[] bfe bih:f64[1] = add bif big bii:f64[1] = sub bih 379.0:f64[] bij:f64[1] = div bii 12.0:f64[] bik:f64[1000] = add bic bij bil:f64[1000] = mul bhg bik bim:f64[1000] = add bfl bfn bin:f64[1000] = div bil bfl bio:f64[1000] = add bin bhf bip:f64[1000] = div bio bfl biq:f64[1000] = div bgm bfl bir:f64[1000] = add bip biq bis:f64[1000] = add bir bgc bit:f64[1000] = div bis bfl biu:f64[1000] = add bim bit biv:f64[1000] = neg bfk biw:f64[1] = sub 1.0:f64[] bfe bix:f64[1000] = neg bfk biy:f64[1000] = log bix biz:f64[1000] = mul biw biy bja:f64[1000] = sub biv biz bjb:f64[1000] = square bja bjc:bool[1000] = gt bfk -4.605170185988091:f64[] bjd:f64[1000] = neg bfk bje:f64[1] = sub 1.0:f64[] bfe bjf:f64[1000] = log bja bjg:f64[1000] = mul bje bjf bjh:f64[1000] = sub bjd bjg bji:f64[1] = sub 3.0:f64[] bfe bjj:f64[1] = mul 2.0:f64[] bji bjk:f64[1000] = mul bjj bja bjl:f64[1000] = add bjb bjk bjm:f64[1] = sub 2.0:f64[] bfe bjn:f64[1] = sub 3.0:f64[] bfe bjo:f64[1] = mul bjm bjn bjp:f64[1000] = add bjl bjo bjq:f64[1] = sub 5.0:f64[] bfe bjr:f64[1000] = mul bjq bja bjs:f64[1000] = add bjb bjr bjt:f64[1000] = add bjs 2.0:f64[] bju:f64[1000] = div bjp bjt bjv:f64[1000] = log bju bjw:f64[1000] = sub bjh bjv bjx:f64[1000] = pjit[name=_where jaxpr=_where] bjc bjw biu bjy:bool[1000] = ge bfk -1.8971199848858813:f64[] bjz:f64[1000] = neg bfk bka:f64[1] = sub bfe 1.0:f64[] bkb:f64[1000] = custom_jvp_call[ name=xlogy call_jaxpr={ lambda ; bkc:f64[1] bkd:f64[1000]. 
let bke:bool[1] = ne bkc 0.0:f64[] bkf:f64[1000] = log bkd bkg:f64[1000] = mul bkc bkf bkh:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 0.0:f64[] bki:f64[1000] = pjit[name=_where jaxpr=_where3] bke bkg bkh in (bki,) } jvp=_xlogy_jvp symbolic_zeros=False ] bka bja bkj:f64[1000] = add bjz bkb bkk:f64[1] = sub 1.0:f64[] bfe bkl:f64[1000] = add 1.0:f64[] bja bkm:f64[1000] = div bkk bkl bkn:f64[1000] = log1p bkm bko:f64[1000] = sub bkj bkn bkp:f64[1000] = pjit[name=_where jaxpr=_where] bjy bko bjx bkq:f64[1000] = exp bfk bkr:f64[1000] = sub -0.5772156649015329:f64[] bkq bks:f64[1000] = exp bkr bkt:f64[1000] = exp bks bku:f64[1000] = mul bks bkt bkv:bool[1] = lt bfe 0.3:f64[] bkw:bool[1000] = ge bfk -1.0498221244986778:f64[] bkx:bool[1000] = and bkv bkw bky:f64[1000] = exp bku bkz:f64[1000] = mul bks bky bla:f64[1000] = pjit[name=_where jaxpr=_where] bkx bkz bkp blb:f64[1000] = exp bfk blc:f64[1000] = mul blb bfg bld:bool[1000] = gt blc 1e-08:f64[] ble:bool[1000] = gt bfg 1e-05:f64[] blf:bool[1000] = and bld ble blg:f64[1] = exp bfh blh:f64[1000] = mul bff blg bli:f64[1000] = mul blh bfe blj:f64[1] = integer_pow[y=-1] bfe blk:f64[1000] = pow bli blj bll:f64[1000] = neg bfg blm:f64[1000] = div bll bfe bln:f64[1000] = sub blm 0.5772156649015329:f64[] blo:f64[1000] = exp bln blp:f64[1000] = pjit[name=_where jaxpr=_where] blf blk blo blq:bool[1000] = gt bfk -0.5108256237659907:f64[] blr:bool[1000] = ge bfk -0.7985076962177716:f64[] bls:bool[1] = ge bfe 0.3:f64[] blt:bool[1000] = and blr bls blu:bool[1000] = or blq blt blv:f64[1] = add bfe 1.0:f64[] blw:f64[1000] = div blp blv blx:f64[1000] = sub 1.0:f64[] blw bly:f64[1000] = div blp blx blz:f64[1000] = pjit[name=_where jaxpr=_where] blu bly bla bma:f64[1] = sqrt bfe bmb:bool[1000] = lt bff 0.5:f64[] bmc:f64[1000] = log bff bmd:f64[1000] = mul -2.0:f64[] bmc bme:f64[1000] = sqrt bmd bmf:f64[1000] = log bfg bmg:f64[1000] = mul -2.0:f64[] bmf bmh:f64[1000] = sqrt bmg bmi:f64[1000] = 
pjit[name=_where jaxpr=_where] bmb bme bmh bmj:f64[1000] = pjit[name=polyval jaxpr=polyval] bfc bmi bmk:f64[1000] = pjit[name=polyval jaxpr=polyval1] bfd bmi bml:f64[1000] = div bmj bmk bmm:f64[1000] = sub bmi bml bmn:bool[1000] = lt bff 0.5:f64[] bmo:f64[1000] = neg bmm bmp:f64[1000] = pjit[name=_where jaxpr=_where] bmn bmo bmm bmq:f64[1000] = square bmp bmr:f64[1000] = mul bmq bmp bms:f64[1000] = square bmq bmt:f64[1000] = mul bms bmp bmu:f64[1000] = mul bmp bma bmv:f64[1000] = add bfe bmu bmw:f64[1000] = sub bmq 1.0:f64[] bmx:f64[1000] = div bmw 3.0:f64[] bmy:f64[1000] = add bmv bmx bmz:f64[1000] = mul 7.0:f64[] bmp bna:f64[1000] = sub bmr bmz bnb:f64[1] = mul 36.0:f64[] bma bnc:f64[1000] = div bna bnb bnd:f64[1000] = add bmy bnc bne:f64[1000] = mul 3.0:f64[] bms bnf:f64[1000] = mul 7.0:f64[] bmq bng:f64[1000] = add bne bnf bnh:f64[1000] = sub bng 16.0:f64[] bni:f64[1] = mul 810.0:f64[] bfe bnj:f64[1000] = div bnh bni bnk:f64[1000] = sub bnd bnj bnl:f64[1000] = mul 9.0:f64[] bmt bnm:f64[1000] = mul 256.0:f64[] bmr bnn:f64[1000] = add bnl bnm bno:f64[1000] = mul 433.0:f64[] bmp bnp:f64[1000] = sub bnn bno bnq:f64[1] = mul 38880.0:f64[] bfe bnr:f64[1] = mul bnq bma bns:f64[1000] = div bnp bnr bnt:f64[1000] = add bnk bns bnu:f64[1] = sub bfe 1.0:f64[] bnv:f64[1] = mul bfe bnu bnw:f64[1] = copy bnv bnx:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 2.0:f64[] bny:f64[1] = custom_jvp_call[ name=_maximum_ call_jaxpr={ lambda ; bnz:f64[1] boa:f64[1]. let bob:f64[1] = max bnz boa in (bob,) } jvp=_maximum_jvp symbolic_zeros=False ] bnx bnw boc:f64[1] = neg bny bod:f64[1] = mul boc 2.302585092994046:f64[] boe:bool[1000] = le bfk bod bof:f64[1000] = neg bfk bog:f64[1] = sub bfe 1.0:f64[] boh:f64[1000] = custom_jvp_call[ name=xlogy call_jaxpr={ lambda ; boi:f64[1] boj:f64[1000]. 
let bok:bool[1] = ne boi 0.0:f64[] bol:f64[1000] = log boj bom:f64[1000] = mul boi bol bon:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 0.0:f64[] boo:f64[1000] = pjit[name=_where jaxpr=_where3] bok bom bon in (boo,) } jvp=_xlogy_jvp symbolic_zeros=False ] bog bof bop:f64[1000] = square boh boq:f64[1000] = mul bop boh bor:f64[1000] = square bop bos:f64[1] = square bfe bot:f64[1] = mul bos bfe bou:f64[1] = sub bfe 1.0:f64[] bov:f64[1000] = add 1.0:f64[] boh bow:f64[1000] = mul bou bov box:f64[1] = sub bfe 1.0:f64[] boy:f64[1] = mul 3.0:f64[] bfe boz:f64[1] = sub boy 5.0:f64[] bpa:f64[1] = div boz 2.0:f64[] bpb:f64[1] = sub bfe 2.0:f64[] bpc:f64[1000] = div boh 2.0:f64[] bpd:f64[1000] = sub bpb bpc bpe:f64[1000] = mul boh bpd bpf:f64[1000] = add bpa bpe bpg:f64[1000] = mul box bpf bph:f64[1] = sub bfe 1.0:f64[] bpi:f64[1000] = div boq 3.0:f64[] bpj:f64[1] = mul 3.0:f64[] bfe bpk:f64[1] = sub bpj 5.0:f64[] bpl:f64[1000] = mul bpk bop bpm:f64[1000] = div bpl 2.0:f64[] bpn:f64[1000] = sub bpi bpm bpo:f64[1] = mul 6.0:f64[] bfe bpp:f64[1] = sub bos bpo bpq:f64[1] = add bpp 7.0:f64[] bpr:f64[1000] = mul bpq boh bps:f64[1000] = add bpn bpr bpt:f64[1] = mul 11.0:f64[] bos bpu:f64[1] = mul 46.0:f64[] bfe bpv:f64[1] = sub bpt bpu bpw:f64[1] = add bpv 47.0:f64[] bpx:f64[1] = div bpw 6.0:f64[] bpy:f64[1000] = add bps bpx bpz:f64[1000] = mul bph bpy bqa:f64[1] = sub bfe 1.0:f64[] bqb:f64[1000] = neg bor bqc:f64[1000] = div bqb 4.0:f64[] bqd:f64[1] = mul 11.0:f64[] bfe bqe:f64[1] = sub bqd 17.0:f64[] bqf:f64[1000] = mul bqe boq bqg:f64[1000] = div bqf 6.0:f64[] bqh:f64[1000] = add bqc bqg bqi:f64[1] = mul -3.0:f64[] bos bqj:f64[1] = mul 13.0:f64[] bfe bqk:f64[1] = add bqi bqj bql:f64[1] = sub bqk 13.0:f64[] bqm:f64[1000] = mul bql bop bqn:f64[1000] = add bqh bqm bqo:f64[1] = mul 2.0:f64[] bot bqp:f64[1] = mul 25.0:f64[] bos bqq:f64[1] = sub bqo bqp bqr:f64[1] = mul 72.0:f64[] bfe bqs:f64[1] = add bqq bqr bqt:f64[1] = sub bqs 61.0:f64[] 
bqu:f64[1000] = mul bqt boh bqv:f64[1000] = div bqu 2.0:f64[] bqw:f64[1000] = add bqn bqv bqx:f64[1] = mul 25.0:f64[] bot bqy:f64[1] = mul 195.0:f64[] bos bqz:f64[1] = sub bqx bqy bra:f64[1] = mul 477.0:f64[] bfe brb:f64[1] = add bqz bra brc:f64[1] = sub brb 379.0:f64[] brd:f64[1] = div brc 12.0:f64[] bre:f64[1000] = add bqw brd brf:f64[1000] = mul bqa bre brg:f64[1000] = add bof boh brh:f64[1000] = div brf bof bri:f64[1000] = add brh bpz brj:f64[1000] = div bri bof brk:f64[1000] = div bpg bof brl:f64[1000] = add brj brk brm:f64[1000] = add brl bow brn:f64[1000] = div brm bof bro:f64[1000] = add brg brn brp:f64[1000] = neg bfk brq:f64[1] = sub bfe 1.0:f64[] brr:f64[1000] = custom_jvp_call[ name=xlogy call_jaxpr={ lambda ; brs:f64[1] brt:f64[1000]. let bru:bool[1] = ne brs 0.0:f64[] brv:f64[1000] = log brt brw:f64[1000] = mul brs brv brx:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 0.0:f64[] bry:f64[1000] = pjit[name=_where jaxpr=_where3] bru brw brx in (bry,) } jvp=_xlogy_jvp symbolic_zeros=False ] brq bnt brz:f64[1000] = add brp brr bsa:f64[1] = sub 1.0:f64[] bfe bsb:f64[1000] = add 1.0:f64[] bnt bsc:f64[1000] = div bsa bsb bsd:f64[1000] = log1p bsc bse:f64[1000] = sub brz bsd bsf:f64[1000] = neg bfk bsg:f64[1] = sub bfe 1.0:f64[] bsh:f64[1000] = custom_jvp_call[ name=xlogy call_jaxpr={ lambda ; bsi:f64[1] bsj:f64[1000]. 
let bsk:bool[1] = ne bsi 0.0:f64[] bsl:f64[1000] = log bsj bsm:f64[1000] = mul bsi bsl bsn:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 0.0:f64[] bso:f64[1000] = pjit[name=_where jaxpr=_where3] bsk bsm bsn in (bso,) } jvp=_xlogy_jvp symbolic_zeros=False ] bsg bse bsp:f64[1000] = add bsf bsh bsq:f64[1] = sub 1.0:f64[] bfe bsr:f64[1000] = add 1.0:f64[] bse bss:f64[1000] = div bsq bsr bst:f64[1000] = log1p bss bsu:f64[1000] = sub bsp bst bsv:f64[1000] = pjit[name=_where jaxpr=_where] boe bro bsu bsw:f64[1] = mul 3.0:f64[] bfe bsx:bool[1000] = lt bnt bsw bsy:f64[1000] = pjit[name=_where jaxpr=_where] bsx bnt bsv bsz:bool[1] = ge bfe 500.0:f64[] bta:f64[1000] = div bnt bfe btb:f64[1000] = sub 1.0:f64[] bta btc:f64[1000] = abs btb btd:bool[1000] = lt btc 1e-06:f64[] bte:bool[1000] = and bsz btd btf:f64[1000] = pjit[name=_where jaxpr=_where] bte bnt bsy btg:f64[1000] = log bff bth:f64[1] = add bfe 1.0:f64[] bti:f64[1] = lgamma bth btj:f64[1000] = add btg bti btk:f64[1000] = add btj bnt btl:f64[1000] = div btk bfe btm:f64[1000] = exp btl btn:f64[1] = add bfe 1.0:f64[] bto:f64[1000] = div btm btn btp:f64[1] = add bfe 2.0:f64[] btq:f64[1000] = div btm btp btr:f64[1000] = add 1.0:f64[] btq bts:f64[1000] = mul bto btr btt:f64[1000] = log1p bts btu:f64[1000] = add btj btm btv:f64[1000] = sub btu btt btw:f64[1000] = div btv bfe btx:f64[1000] = exp btw bty:f64[1] = add bfe 1.0:f64[] btz:f64[1000] = div btx bty bua:f64[1] = add bfe 2.0:f64[] bub:f64[1000] = div btx bua buc:f64[1000] = add 1.0:f64[] bub bud:f64[1000] = mul btz buc bue:f64[1000] = log1p bud buf:f64[1000] = add btj btx bug:f64[1000] = sub buf bue buh:f64[1000] = div bug bfe bui:f64[1000] = exp buh buj:f64[1] = add bfe 1.0:f64[] buk:f64[1000] = div bui buj bul:f64[1] = add bfe 2.0:f64[] bum:f64[1000] = div bui bul bun:f64[1] = add bfe 3.0:f64[] buo:f64[1000] = div bui bun bup:f64[1000] = add 1.0:f64[] buo buq:f64[1000] = mul bum bup bur:f64[1000] = add 1.0:f64[] buq bus:f64[1000] = mul 
buk bur but:f64[1000] = log1p bus buu:f64[1000] = add btj bui buv:f64[1000] = sub buu but buw:f64[1000] = div buv bfe bux:f64[1000] = exp buw buy:f64[1] = add bfe 1.0:f64[] buz:f64[1] = mul 0.15:f64[] buy bva:bool[1000] = le bnt buz bvb:f64[1000] = pjit[name=_where jaxpr=_where] bva bux bnt bvc:bool[1000] = broadcast_in_dim[ broadcast_dimensions=() shape=(1000,) sharding=None ] False:bool[] bvd:f64[1000] = broadcast_in_dim[ broadcast_dimensions=() shape=(1000,) sharding=None ] 1.0:f64[] bve:f64[1000] = broadcast_in_dim[ broadcast_dimensions=() shape=(1000,) sharding=None ] 1.0:f64[] _:bool[1000] _:f64[] _:f64[1000] bvf:f64[1000] = while[ body_jaxpr={ lambda ; bvg:f64[1000] bvh:f64[1] bvi:bool[1000] bvj:f64[] bvk:f64[1000] bvl:f64[1000]. let bvm:f64[1000] = mul bvk bvg bvn:f64[1] = add bvh bvj bvo:f64[1000] = div bvm bvn bvp:f64[1000] = add bvl bvo bvq:f64[1000] = pjit[name=_where jaxpr=_where4] bvi bvl bvp bvr:bool[1000] = lt bvo 0.0001:f64[] bvs:bool[] = gt bvj 100.0:f64[] bvt:bool[1000] = or bvr bvs bvu:f64[] = add bvj 1.0:f64[] in (bvt, bvu, bvo, bvq) } body_nconsts=2 cond_jaxpr={ lambda ; bvv:bool[1000] bvw:f64[] bvx:f64[1000] bvy:f64[1000]. 
let bvz:bool[1000] = not bvv bwa:bool[] = reduce_or[axes=(0,)] bvz in (bwa,) } cond_nconsts=0 ] bvb bfe bvc 1.0:f64[] bvd bve bwb:f64[1000] = log bvf bwc:f64[1000] = add btj bvb bwd:f64[1000] = sub bwc bwb bwe:f64[1000] = div bwd bfe bwf:f64[1000] = exp bwe bwg:f64[1] = add bfe 1.0:f64[] bwh:f64[1] = mul 0.01:f64[] bwg bwi:bool[1000] = le bvb bwh bwj:f64[1] = add bfe 1.0:f64[] bwk:f64[1] = mul 0.7:f64[] bwj bwl:bool[1000] = gt bvb bwk bwm:bool[1000] = or bwi bwl bwn:f64[1000] = log bwf bwo:f64[1000] = mul bfe bwn bwp:f64[1000] = sub bwo bwf bwq:f64[1000] = sub bwp btj bwr:f64[1000] = add bwq bwb bws:f64[1000] = sub bfe bwf bwt:f64[1000] = div bwr bws bwu:f64[1000] = sub 1.0:f64[] bwt bwv:f64[1000] = mul bwf bwu bww:f64[1000] = pjit[name=_where jaxpr=_where] bwm bvb bwv bwx:bool[1000] = le bff 0.5:f64[] bwy:f64[1000] = pjit[name=_where jaxpr=_where] bwx bww btf bwz:bool[1] = lt bfe 1.0:f64[] bxa:f64[1000] = pjit[name=_where jaxpr=_where1] bwz blz bwy bxb:bool[1] = eq bfe 1.0:f64[] bxc:f64[1000] = neg bfj bxd:f64[1000] = pjit[name=_where jaxpr=_where1] bxb bxc bxa bxe:f64[1000] = log bxd bxf:f64[1000] = mul bfe bxe bxg:f64[1000] = sub bxf bxd bxh:f64[1] = lgamma bfe bxi:f64[1000] = sub bxg bxh bxj:f64[1000] = exp bxi bxk:bool[1000] = le bff 0.9:f64[] bxl:bool[1000] = and bxk True:bool[] bxm:bool[1000] = gt bfg 0.9:f64[] bxn:bool[1000] = and bxm False:bool[] bxo:bool[1000] = or bxl bxn bxp:f64[1000] = igamma bfe bxd bxq:f64[1000] = sub bxp bff bxr:f64[1000] = mul bxq bxd bxs:f64[1000] = div bxr bxj bxt:f64[1000] = igammac bfe bxd bxu:f64[1000] = sub bxt bfg bxv:f64[1000] = neg bxu bxw:f64[1000] = mul bxv bxd bxx:f64[1000] = div bxw bxj bxy:f64[1000] = pjit[name=_where jaxpr=_where] bxo bxs bxx bxz:f64[1] = sub bfe 1.0:f64[] bya:f64[1000] = div bxz bxd byb:f64[1000] = add -1.0:f64[] bya byc:bool[1000] = pjit[name=isinf jaxpr=isinf] byb byd:f64[1000] = sub bxd bxy bye:f64[1000] = mul 0.5:f64[] bxy byf:f64[1000] = mul bye byb byg:f64[1000] = sub 1.0:f64[] byf 
byh:f64[1000] = div bxy byg byi:f64[1000] = sub bxd byh byj:f64[1000] = pjit[name=_where jaxpr=_where] byc byd byi byk:bool[1000] = eq bxj 0.0:f64[] byl:f64[1000] = pjit[name=_where jaxpr=_where] byk bxd byj bym:f64[1000] = log byl byn:f64[1000] = mul bfe bym byo:f64[1000] = sub byn byl byp:f64[1] = lgamma bfe byq:f64[1000] = sub byo byp byr:f64[1000] = exp byq bys:bool[1000] = le bff 0.9:f64[] byt:bool[1000] = and bys True:bool[] byu:bool[1000] = gt bfg 0.9:f64[] byv:bool[1000] = and byu False:bool[] byw:bool[1000] = or byt byv byx:f64[1000] = igamma bfe byl byy:f64[1000] = sub byx bff byz:f64[1000] = mul byy byl bza:f64[1000] = div byz byr bzb:f64[1000] = igammac bfe byl bzc:f64[1000] = sub bzb bfg bzd:f64[1000] = neg bzc bze:f64[1000] = mul bzd byl bzf:f64[1000] = div bze byr bzg:f64[1000] = pjit[name=_where jaxpr=_where] byw bza bzf bzh:f64[1] = sub bfe 1.0:f64[] bzi:f64[1000] = div bzh byl bzj:f64[1000] = add -1.0:f64[] bzi bzk:bool[1000] = pjit[name=isinf jaxpr=isinf] bzj bzl:f64[1000] = sub byl bzg bzm:f64[1000] = mul 0.5:f64[] bzg bzn:f64[1000] = mul bzm bzj bzo:f64[1000] = sub 1.0:f64[] bzn bzp:f64[1000] = div bzg bzo bzq:f64[1000] = sub byl bzp bzr:f64[1000] = pjit[name=_where jaxpr=_where] bzk bzl bzq bzs:bool[1000] = eq byr 0.0:f64[] bzt:f64[1000] = pjit[name=_where jaxpr=_where] bzs byl bzr bzu:f64[1000] = log bzt bzv:f64[1000] = mul bfe bzu bzw:f64[1000] = sub bzv bzt bzx:f64[1] = lgamma bfe bzy:f64[1000] = sub bzw bzx bzz:f64[1000] = exp bzy caa:bool[1000] = le bff 0.9:f64[] cab:bool[1000] = and caa True:bool[] cac:bool[1000] = gt bfg 0.9:f64[] cad:bool[1000] = and cac False:bool[] cae:bool[1000] = or cab cad caf:f64[1000] = igamma bfe bzt cag:f64[1000] = sub caf bff cah:f64[1000] = mul cag bzt cai:f64[1000] = div cah bzz caj:f64[1000] = igammac bfe bzt cak:f64[1000] = sub caj bfg cal:f64[1000] = neg cak cam:f64[1000] = mul cal bzt can:f64[1000] = div cam bzz cao:f64[1000] = pjit[name=_where jaxpr=_where] cae cai can cap:f64[1] = sub bfe 1.0:f64[] 
caq:f64[1000] = div cap bzt car:f64[1000] = add -1.0:f64[] caq cas:bool[1000] = pjit[name=isinf jaxpr=isinf] car cat:f64[1000] = sub bzt cao cau:f64[1000] = mul 0.5:f64[] cao cav:f64[1000] = mul cau car caw:f64[1000] = sub 1.0:f64[] cav cax:f64[1000] = div cao caw cay:f64[1000] = sub bzt cax caz:f64[1000] = pjit[name=_where jaxpr=_where] cas cat cay cba:bool[1000] = eq bzz 0.0:f64[] cbb:f64[1000] = pjit[name=_where jaxpr=_where] cba bzt caz cbc:bool[1] = lt bfe 0.0:f64[] cbd:bool[1000] = lt bff 0.0:f64[] cbe:bool[1000] = or cbc cbd cbf:bool[1000] = gt bff 1.0:f64[] cbg:bool[1000] = or cbe cbf cbh:f64[1000] = pjit[name=_where jaxpr=_where2] cbg nan:f64[] cbb cbi:bool[1000] = eq bff 0.0:f64[] cbj:f64[1000] = pjit[name=_where jaxpr=_where2] cbi 0.0:f64[] cbh cbk:bool[1000] = eq bff 1.0:f64[] cbl:f64[1000] = pjit[name=_where jaxpr=_where2] cbk inf:f64[] cbj in (cbl,) } jvp=_igammainv_jvp num_consts=2 symbolic_zeros=False ] dg dh bfa bez cbm:f64[1000] = mul 2.0:f64[] bfb cbn:f64[1,2,2] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 2, 2) sharding=None ] dr cbo:f64[1000] = div cbm bev cbp:f64[1000] = sqrt cbo cbq:f64[1000,1,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(1000, 1, 1) sharding=None ] cbp cbr:f64[1,2,2] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 2, 2) sharding=None ] dr cbs:f64[1000,2,2] = sub ho cbr cbt:f64[1000,2,2] = mul cbq cbs cbu:f64[1000,2,2] = add cbn cbt cbv:f64[4000,2,2] = concatenate[dimension=0] hk ho ber cbu cbw:f64[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=None] di cbx:f64[2] = squeeze[dimensions=(0,)] cbw cby:f64[2] cbz:f64[2,2] = pjit[ name=eigh jaxpr={ lambda ; dj:f64[2,2]. 
let cca:f64[2,2] = transpose[permutation=(1, 0)] dj ccb:f64[2,2] = add dj cca ccc:f64[2,2] = div ccb 2.0:f64[] cbz:f64[2,2] cby:f64[2] = eigh[ lower=True sort_eigenvalues=True subset_by_index=None ] ccc in (cby, cbz) } ] dj ccd:f64[2] = abs cby cce:f64[2] = sqrt ccd ccf:f64[2,2] = dot_general[ dimension_numbers=(([], []), ([1], [0])) preferred_element_type=float64 ] cbz cce ccg:f64[2,2] = pjit[ name=qr jaxpr={ lambda ; ccf:f64[2,2]. let cch:f64[2,2] ccg:f64[2,2] = qr[ full_matrices=True pivoting=False use_magma=None ] ccf in (ccg,) } ] ccf cci:u32[2,2] = iota[dimension=0 dtype=uint32 shape=(2, 2) sharding=None] ccj:u32[2,2] = iota[dimension=1 dtype=uint32 shape=(2, 2) sharding=None] cck:bool[2,2] = eq cci ccj ccl:f64[2,2] = broadcast_in_dim[ broadcast_dimensions=() shape=(2, 2) sharding=None ] 0.0:f64[] ccm:f64[2,2] = select_n cck ccl ccg ccn:f64[2] = reduce_sum[axes=(0,)] ccm cco:f64[2,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(2, 1) sharding=None ] ccn ccp:f64[2,1] = sign cco ccq:f64[2,2] = mul ccg ccp ccr:f64[2,2] = transpose[permutation=(1, 0)] ccq ccs:f64[4000,1,2] = slice[ limit_indices=(4000, 1, 2) start_indices=(0, 0, 0) strides=None ] cbv cct:f64[4000,2] = squeeze[dimensions=(1,)] ccs ccu:f64[1,2] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 2) sharding=None ] cbx ccv:f64[4000,2] = sub cct ccu ccw:f64[4000,2,1] = broadcast_in_dim[ broadcast_dimensions=(0, 1) shape=(4000, 2, 1) sharding=None ] ccv ccx:f64[2,2] = pjit[name=tril jaxpr=tril] ccr ccy:f64[2,1,4000] = pjit[ name=_solve_triangular jaxpr={ lambda ; ccx:f64[2,2] ccw:f64[4000,2,1]. 
let ccz:f64[2,1,4000] = transpose[permutation=(1, 2, 0)] ccw cda:f64[2,4000] = reshape[ dimensions=None new_sizes=(2, 4000) sharding=None ] ccz cdb:f64[2,4000] = triangular_solve[ conjugate_a=False left_side=True lower=True transpose_a=False unit_diagonal=False ] ccx cda ccy:f64[2,1,4000] = reshape[ dimensions=None new_sizes=(2, 1, 4000) sharding=None ] cdb in (ccy,) } ] ccx ccw cdc:f64[4000,2,1] = transpose[permutation=(2, 0, 1)] ccy cdd:f64[4000,2] = squeeze[dimensions=(2,)] cdc cde:f64[4000,2] = div cdd 1.0:f64[] cdf:f64[] = div 0.0:f64[] 1.0:f64[] cdg:f64[4000,2] = sub cde cdf cdh:f64[4000,2] = square cdg cdi:f64[4000,2] = mul -0.5:f64[] cdh cdj:f64[] = log 1.0:f64[] cdk:f64[] = add 0.9189385332046727:f64[] cdj cdl:f64[4000,2] = sub cdi cdk cdm:f64[4000] = reduce_sum[axes=(1,)] cdl cdn:f64[2] = pjit[name=diagonal jaxpr=diagonal] ccr cdo:f64[2] = abs cdn cdp:f64[2] = log cdo cdq:f64[] = reduce_sum[axes=(0,)] cdp cdr:f64[] = neg cdq cds:f64[] = neg cdr cdt:f64[] = mul 1.0:f64[] cds cdu:f64[] = reduce_sum[axes=()] cdt cdv:f64[] = add 0.0:f64[] cdu cdw:f64[] = neg 0.0:f64[] cdx:f64[] = neg cdw cdy:f64[2] = broadcast_in_dim[ broadcast_dimensions=() shape=(2,) sharding=None ] 1.0:f64[] cdz:f64[2] = mul cdy cdx cea:f64[] = reduce_sum[axes=(0,)] cdz ceb:f64[] = add cdv cea cec:f64[] = copy ceb ced:f64[4000] = sub cdm cec cee:f64[4000,1,2] = slice[ limit_indices=(4000, 1, 2) start_indices=(0, 0, 0) strides=None ] cbv cef:f64[1,1,2] = transpose[permutation=(0, 2, 1)] dk ceg:f64[4000,1,2] = slice[ limit_indices=(4000, 2, 2) start_indices=(0, 1, 0) strides=None ] cbv ceh:f64[1,2] = slice[limit_indices=(2, 2) start_indices=(1, 0) strides=None] di cei:f64[1,1,2] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 1, 2) sharding=None ] ceh cej:f64[4000,1,2] = sub ceg cei cek:f64[4000,1,2,1] = broadcast_in_dim[ broadcast_dimensions=(0, 1, 2) shape=(4000, 1, 2, 1) sharding=None ] cee cel:f64[2,2] = squeeze[dimensions=(0,)] dl cem:f64[2,4000,1,1] = dot_general[ 
dimension_numbers=(([1], [2]), ([], [])) preferred_element_type=float64 ] cel cek cen:f64[4000,1,2,1] = transpose[permutation=(1, 2, 0, 3)] cem ceo:f64[4000,1,2,1] = slice[ limit_indices=(4000, 1, 2, 1) start_indices=(0, 0, 0, 0) strides=None ] cen cep:f64[4000,1,2] = squeeze[dimensions=(3,)] ceo ceq:f64[4000,1,2] = sub cej cep cer:f64[1,1,4000] = dot_general[ dimension_numbers=(([2], [2]), ([0], [1])) preferred_element_type=float64 ] cef ceq ces:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 0.0:f64[] cet:f64[1,1] ceu:f64[1,1,1] = pjit[ name=eigh jaxpr={ lambda ; dm:f64[1,1,1]. let cev:f64[1,1,1] = transpose[permutation=(0, 2, 1)] dm cew:f64[1,1,1] = add dm cev cex:f64[1,1,1] = div cew 2.0:f64[] ceu:f64[1,1,1] cet:f64[1,1] = eigh[ lower=True sort_eigenvalues=True subset_by_index=None ] cex in (cet, ceu) } ] dm cey:f64[1,1] = abs cet cez:f64[1,1] = sqrt cey cfa:f64[1,1,1] = dot_general[ dimension_numbers=(([], []), ([0, 2], [0, 1])) preferred_element_type=float64 ] ceu cez cfb:f64[1,1,1] = pjit[ name=qr jaxpr={ lambda ; cfa:f64[1,1,1]. 
let cfc:f64[1,1,1] cfb:f64[1,1,1] = qr[ full_matrices=True pivoting=False use_magma=None ] cfa in (cfb,) } ] cfa cfd:u32[1,1] = iota[dimension=0 dtype=uint32 shape=(1, 1) sharding=None] cfe:u32[1,1] = iota[dimension=1 dtype=uint32 shape=(1, 1) sharding=None] cff:bool[1,1] = eq cfd cfe cfg:bool[1,1,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 1, 1) sharding=None ] cff cfh:f64[1,1,1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1, 1, 1) sharding=None ] 0.0:f64[] cfi:f64[1,1,1] = select_n cfg cfh cfb cfj:f64[1,1] = reduce_sum[axes=(1,)] cfi cfk:f64[1,1,1] = broadcast_in_dim[ broadcast_dimensions=(0, 1) shape=(1, 1, 1) sharding=None ] cfj cfl:f64[1,1,1] = sign cfk cfm:f64[1,1,1] = mul cfb cfl cfn:f64[1,1,1] = transpose[permutation=(0, 2, 1)] cfm cfo:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 0.0:f64[] cfp:f64[1,1] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 1) sharding=None ] ces cfq:f64[4000,1,1] = transpose[permutation=(2, 0, 1)] cer cfr:f64[1,1,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 1, 1) sharding=None ] cfp cfs:f64[4000,1,1] = sub cfq cfr cft:f64[4000,1,1,1] = broadcast_in_dim[ broadcast_dimensions=(0, 1, 2) shape=(4000, 1, 1, 1) sharding=None ] cfs cfu:f64[1,1,1] = pjit[ name=tril jaxpr={ lambda ; cfn:f64[1,1,1]. let cfv:i32[1,1] = iota[ dimension=0 dtype=int32 shape=(1, 1) sharding=None ] cfw:i32[1,1] = add cfv 0:i32[] cfx:i32[1,1] = iota[ dimension=1 dtype=int32 shape=(1, 1) sharding=None ] cfy:bool[1,1] = ge cfw cfx cfz:bool[1,1,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 1, 1) sharding=None ] cfy cga:f64[1,1,1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1, 1, 1) sharding=None ] 0.0:f64[] cfu:f64[1,1,1] = select_n cfz cga cfn in (cfu,) } ] cfn cgb:f64[1,1,1,4000] = pjit[ name=_solve_triangular jaxpr={ lambda ; cfu:f64[1,1,1] cft:f64[4000,1,1,1]. 
let cgc:f64[1,1,1,4000] = transpose[permutation=(1, 2, 3, 0)] cft cgd:f64[1,1,4000] = reshape[ dimensions=None new_sizes=(1, 1, 4000) sharding=None ] cgc cge:f64[1,1,4000] = triangular_solve[ conjugate_a=False left_side=True lower=True transpose_a=False unit_diagonal=False ] cfu cgd cgb:f64[1,1,1,4000] = reshape[ dimensions=None new_sizes=(1, 1, 1, 4000) sharding=None ] cge in (cgb,) } ] cfu cft cgf:f64[4000,1,1,1] = transpose[permutation=(3, 0, 1, 2)] cgb cgg:f64[4000,1,1] = squeeze[dimensions=(3,)] cgf cgh:f64[4000,1,1] = transpose[permutation=(0, 2, 1)] cgg cgi:f64[4000,1,1] = div cgh 1.0:f64[] cgj:f64[1] = div cfo 1.0:f64[] cgk:f64[1,1] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 1) sharding=None ] cgj cgl:f64[1,1,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 1, 1) sharding=None ] cgk cgm:f64[4000,1,1] = sub cgi cgl cgn:f64[4000,1,1] = square cgm cgo:f64[4000,1,1] = mul -0.5:f64[] cgn cgp:f64[] = log 1.0:f64[] cgq:f64[] = add 0.9189385332046727:f64[] cgp cgr:f64[4000,1,1] = sub cgo cgq cgs:f64[4000,1] = reduce_sum[axes=(1,)] cgr cgt:f64[1,1] = pjit[ name=diagonal jaxpr={ lambda ; cfn:f64[1,1,1]. 
let cgu:i64[1] = iota[dimension=0 dtype=int64 shape=(1,) sharding=None] cgv:i64[1] = iota[dimension=0 dtype=int64 shape=(1,) sharding=None] cgw:bool[1] = lt cgu 0:i64[] cgx:i64[1] = add cgu 1:i64[] cgy:i64[1] = select_n cgw cgu cgx cgz:bool[1] = lt cgv 0:i64[] cha:i64[1] = add cgv 1:i64[] chb:i64[1] = select_n cgz cgv cha chc:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] cgy chd:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] chb che:i32[1,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(1, 1) sharding=None ] chc chf:i32[1,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(1, 1) sharding=None ] chd chg:i32[1,2] = concatenate[dimension=1] che chf cgt:f64[1,1] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(1, 2), start_index_map=(1, 2), operand_batching_dims=(), start_indices_batching_dims=()) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(1, 1, 1) unique_indices=False ] cfn chg in (cgt,) } ] cfn chh:f64[1,1] = abs cgt chi:f64[1,1] = log chh chj:f64[1] = reduce_sum[axes=(1,)] chi chk:f64[1] = neg chj chl:f64[1] = neg chk chm:f64[1] = mul 1.0:f64[] chl chn:f64[1] = reduce_sum[axes=()] chm cho:f64[1] = add 0.0:f64[] chn chp:f64[] = neg 0.0:f64[] chq:f64[] = neg chp chr:f64[1] = broadcast_in_dim[ broadcast_dimensions=() shape=(1,) sharding=None ] 1.0:f64[] chs:f64[1] = mul chr chq cht:f64[] = reduce_sum[axes=(0,)] chs chu:f64[1] = add cho cht chv:f64[1] = copy chu chw:f64[1,1] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 1) sharding=None ] chv chx:f64[4000,1] = sub cgs chw chy:f64[4000,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(4000, 1) sharding=None ] ced chz:f64[4000,2] = concatenate[dimension=1] chy chx cia:f64[4000] = reduce_sum[axes=(1,)] chz cib:f64[4000,2,2,1] = broadcast_in_dim[ broadcast_dimensions=(0, 1, 2) shape=(4000, 2, 2, 1) sharding=None ] cbv cic:f64[2,1,4000,1] = dot_general[ 
dimension_numbers=(([2], [2]), ([0], [1])) preferred_element_type=float64 ] dn cib cid:f64[2,1,4000,1] = slice[ limit_indices=(2, 1, 4000, 1) start_indices=(0, 0, 0, 0) strides=None ] cic cie:f64[4000,2,1,1] = transpose[permutation=(2, 0, 1, 3)] cid cif:f64[4000,2,1] = squeeze[dimensions=(3,)] cie cig:f64[1,2,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 2, 1) sharding=None ] do cih:f64[4000,2,1] = add cig cif cii:f64[4000,2,1] = copy cih cij:f64[2,1] = copy dp cik:f64[2,1] = broadcast_in_dim[ broadcast_dimensions=() shape=(2, 1) sharding=None ] 0.0:f64[] cil:f64[2,1] = custom_jvp_call[ name=_maximum_ call_jaxpr={ lambda ; cim:f64[2,1] cin:f64[2,1]. let cio:f64[2,1] = max cim cin in (cio,) } jvp=_maximum_jvp symbolic_zeros=False ] cij cik cip:bool[2,1] = eq cil 0.0:f64[] ciq:f64[1,2,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 2, 1) sharding=None ] cil cir:f64[4000,2,1] = mul cii ciq cis:f64[4000,2,1] = pjit[ name=_where jaxpr={ lambda ; cip:bool[2,1] cit:f64[] cir:f64[4000,2,1]. let ciu:f64[2,1] = broadcast_in_dim[ broadcast_dimensions=() shape=(2, 1) sharding=None ] cit civ:bool[4000,2,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(4000, 2, 1) sharding=None ] cip ciw:f64[4000,2,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(4000, 2, 1) sharding=None ] ciu cis:f64[4000,2,1] = select_n civ cir ciw in (cis,) } ] cip 0.0:f64[] cir cix:f64[2,1] = add 1.0:f64[] cil ciy:f64[2,1] = lgamma cix ciz:f64[1,2,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 2, 1) sharding=None ] ciy cja:f64[4000,2,1] = sub cis ciz cjb:bool[2,1] = eq dp cil cjc:f64[4000,2,1] = pjit[ name=_where jaxpr={ lambda ; cjb:bool[2,1] cja:f64[4000,2,1] cjd:f64[]. 
let cje:f64[2,1] = broadcast_in_dim[ broadcast_dimensions=() shape=(2, 1) sharding=None ] cjd cjf:bool[4000,2,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(4000, 2, 1) sharding=None ] cjb cjg:f64[4000,2,1] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(4000, 2, 1) sharding=None ] cje cjc:f64[4000,2,1] = select_n cjf cjg cja in (cjc,) } ] cjb cja -inf:f64[] cjh:f64[4000,2,1] = exp cii cji:f64[4000,2,1] = sub cjc cjh cjj:f64[4000,2] = reduce_sum[axes=(2,)] cji cjk:f64[4000] = reduce_sum[axes=(1,)] cjj cjl:f64[4000] = add cia cjk cjm:f64[1,2,2] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 2, 2) sharding=None ] dr cjn:f64[4000,2,2] = sub cbv cjm cjo:f64[4000,1,2] = slice[ limit_indices=(4000, 2, 2) start_indices=(0, 1, 0) strides=None ] cjn cjp:f64[4000,1,2] = slice[ limit_indices=(4000, 1, 2) start_indices=(0, 0, 0) strides=None ] cjn cjq:f64[4000,1,2] = pjit[ name=_solve_triangular jaxpr={ lambda ; dt:f64[1,2,2] cjp:f64[4000,1,2]. let cjr:f64[4000,1,2,1] = broadcast_in_dim[ broadcast_dimensions=(0, 1, 2) shape=(4000, 1, 2, 1) sharding=None ] cjp cjs:f64[1,2,1,4000] = transpose[permutation=(1, 2, 3, 0)] cjr cjt:f64[1,2,4000] = reshape[ dimensions=None new_sizes=(1, 2, 4000) sharding=None ] cjs cju:f64[1,2,4000] = triangular_solve[ conjugate_a=False left_side=True lower=True transpose_a=False unit_diagonal=False ] dt cjt cjv:f64[1,2,1,4000] = reshape[ dimensions=None new_sizes=(1, 2, 1, 4000) sharding=None ] cju cjw:f64[1,2,1,4000] = slice[ limit_indices=(1, 2, 1, 4000) start_indices=(0, 0, 0, 0) strides=None ] cjv cjx:f64[4000,1,2,1] = transpose[permutation=(3, 0, 1, 2)] cjw cjq:f64[4000,1,2] = squeeze[dimensions=(3,)] cjx in (cjq,) } ] dt cjp cjy:f64[1,2,4000] = dot_general[ dimension_numbers=(([2], [2]), ([0], [1])) preferred_element_type=float64 ] du cjq cjz:f64[4000,1,2] = transpose[permutation=(2, 0, 1)] cjy cka:f64[4000,1,2] = sub cjo cjz ckb:f64[4000,1,2] = slice[ limit_indices=(4000, 1, 2) start_indices=(0, 0, 0) strides=None ] 
cjn ckc:f64[4000,2,2] = concatenate[dimension=1] ckb cka ckd:f64[4000,2,2] = pjit[ name=_solve_triangular jaxpr={ lambda ; ds:f64[2,2,2] ckc:f64[4000,2,2]. let cke:f64[4000,2,2,1] = broadcast_in_dim[ broadcast_dimensions=(0, 1, 2) shape=(4000, 2, 2, 1) sharding=None ] ckc ckf:f64[2,2,1,4000] = transpose[permutation=(1, 2, 3, 0)] cke ckg:f64[2,2,4000] = reshape[ dimensions=None new_sizes=(2, 2, 4000) sharding=None ] ckf ckh:f64[2,2,4000] = triangular_solve[ conjugate_a=False left_side=True lower=True transpose_a=False unit_diagonal=False ] ds ckg cki:f64[2,2,1,4000] = reshape[ dimensions=None new_sizes=(2, 2, 1, 4000) sharding=None ] ckh ckj:f64[2,2,1,4000] = slice[ limit_indices=(2, 2, 1, 4000) start_indices=(0, 0, 0, 0) strides=None ] cki ckk:f64[4000,2,2,1] = transpose[permutation=(3, 0, 1, 2)] ckj ckd:f64[4000,2,2] = squeeze[dimensions=(3,)] ckk in (ckd,) } ] ds ckc ckl:f64[2] = broadcast_in_dim[ broadcast_dimensions=() shape=(2,) sharding=None ] 0.0:f64[] ckm:i64[2,2] = iota[dimension=0 dtype=int64 shape=(2, 2) sharding=None] ckn:i64[2,2] = iota[dimension=1 dtype=int64 shape=(2, 2) sharding=None] cko:i64[2,2] = add ckm 0:i64[] ckp:bool[2,2] = eq cko ckn ckq:f64[2,2] = convert_element_type[new_dtype=float64 weak_type=False] ckp ckr:f64[2,2] = pjit[ name=cholesky jaxpr={ lambda ; ckq:f64[2,2]. 
let cks:f64[2,2] = transpose[permutation=(1, 0)] ckq ckt:f64[2,2] = add ckq cks cku:f64[2,2] = div ckt 2.0:f64[] ckv:f64[2,2] = cholesky cku ckw:i32[2,2] = iota[ dimension=0 dtype=int32 shape=(2, 2) sharding=None ] ckx:i32[2,2] = add ckw 0:i32[] cky:i32[2,2] = iota[ dimension=1 dtype=int32 shape=(2, 2) sharding=None ] ckz:bool[2,2] = ge ckx cky cla:f64[2,2] = broadcast_in_dim[ broadcast_dimensions=() shape=(2, 2) sharding=None ] 0.0:f64[] ckr:f64[2,2] = select_n ckz cla ckv in (ckr,) } ] ckq clb:f64[1,2] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 2) sharding=None ] ckl clc:f64[1,1,2] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 1, 2) sharding=None ] clb cld:f64[4000,2,2] = sub ckd clc cle:f64[4000,2,2,1] = broadcast_in_dim[ broadcast_dimensions=(0, 1, 2) shape=(4000, 2, 2, 1) sharding=None ] cld clf:f64[2,2] = pjit[name=tril jaxpr=tril] ckr clg:f64[4000,2,1,2] = transpose[permutation=(0, 2, 3, 1)] cle clh:f64[4000,2,2] = reshape[ dimensions=None new_sizes=(4000, 2, 2) sharding=None ] clg cli:f64[2,2,4000] = pjit[ name=_solve_triangular jaxpr={ lambda ; clf:f64[2,2] clh:f64[4000,2,2]. 
let clj:f64[2,2,4000] = transpose[permutation=(1, 2, 0)] clh clk:f64[2,8000] = reshape[ dimensions=None new_sizes=(2, 8000) sharding=None ] clj cll:f64[2,8000] = triangular_solve[ conjugate_a=False left_side=True lower=True transpose_a=False unit_diagonal=False ] clf clk cli:f64[2,2,4000] = reshape[ dimensions=None new_sizes=(2, 2, 4000) sharding=None ] cll in (cli,) } ] clf clh clm:f64[4000,2,2] = transpose[permutation=(2, 0, 1)] cli cln:f64[4000,2,1,2] = reshape[ dimensions=None new_sizes=(4000, 2, 1, 2) sharding=None ] clm clo:f64[4000,2,2,1] = transpose[permutation=(0, 3, 1, 2)] cln clp:f64[4000,2,2] = squeeze[dimensions=(3,)] clo clq:f64[4000,2,2] = div clp 1.0:f64[] clr:f64[] = div 0.0:f64[] 1.0:f64[] cls:f64[4000,2,2] = sub clq clr clt:f64[4000,2,2] = square cls clu:f64[4000,2,2] = mul -0.5:f64[] clt clv:f64[] = log 1.0:f64[] clw:f64[] = add 0.9189385332046727:f64[] clv clx:f64[4000,2,2] = sub clu clw cly:f64[4000,2] = reduce_sum[axes=(2,)] clx clz:f64[2] = pjit[name=diagonal jaxpr=diagonal] ckr cma:f64[2] = abs clz cmb:f64[2] = log cma cmc:f64[] = reduce_sum[axes=(0,)] cmb cmd:f64[] = neg cmc cme:f64[] = neg cmd cmf:f64[] = mul 1.0:f64[] cme cmg:f64[] = reduce_sum[axes=()] cmf cmh:f64[] = add 0.0:f64[] cmg cmi:f64[] = neg 0.0:f64[] cmj:f64[] = neg cmi cmk:f64[2] = broadcast_in_dim[ broadcast_dimensions=() shape=(2,) sharding=None ] 1.0:f64[] cml:f64[2] = mul cmk cmj cmm:f64[] = reduce_sum[axes=(0,)] cml cmn:f64[] = add cmh cmm cmo:f64[] = copy cmn cmp:f64[4000,2] = sub cly cmo cmq:f64[4000] = reduce_sum[axes=(1,)] cmp cmr:f64[2,2] = pjit[ name=_diag jaxpr={ lambda ; ds:f64[2,2,2]. let cmr:f64[2,2] = pjit[ name=diagonal jaxpr={ lambda ; ds:f64[2,2,2]. 
let cms:i64[2,2] = iota[ dimension=0 dtype=int64 shape=(2, 2) sharding=None ] cmt:i64[2,2] = iota[ dimension=1 dtype=int64 shape=(2, 2) sharding=None ] cmu:i64[2,2] = add cms 0:i64[] cmv:bool[2,2] = eq cmu cmt cmw:f64[2,2] = convert_element_type[ new_dtype=float64 weak_type=False ] cmv cmx:i32[] = platform_index[platforms=(('mosaic',), None)] cmr:f64[2,2] = cond[ branches=( { lambda ; cmy:f64[2,2] cmz:f64[2,2,2]. let cna:f64[1,2,2] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 2, 2) sharding=None ] cmy cnb:f64[2,2,2] = mul cna cmz cnc:f64[2,2] = reduce_sum[axes=(1,)] cnb in (cnc,) } { lambda ; cnd:f64[2,2] cne:f64[2,2,2]. let cnf:i64[2] = iota[ dimension=0 dtype=int64 shape=(2,) sharding=None ] cng:i64[2] = iota[ dimension=0 dtype=int64 shape=(2,) sharding=None ] cnh:bool[2] = lt cnf 0:i64[] cni:i64[2] = add cnf 2:i64[] cnj:i64[2] = select_n cnh cnf cni cnk:bool[2] = lt cng 0:i64[] cnl:i64[2] = add cng 2:i64[] cnm:i64[2] = select_n cnk cng cnl cnn:i32[2] = convert_element_type[ new_dtype=int32 weak_type=False ] cnj cno:i32[2] = convert_element_type[ new_dtype=int32 weak_type=False ] cnm cnp:i32[2,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(2, 1) sharding=None ] cnn cnq:i32[2,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(2, 1) sharding=None ] cno cnr:i32[2,2] = concatenate[dimension=1] cnp cnq cns:f64[2,2] = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(1, 2), start_index_map=(1, 2), operand_batching_dims=(), start_indices_batching_dims=()) fill_value=None indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS slice_sizes=(2, 1, 1) unique_indices=False ] cne cnr in (cns,) } ) branches_platforms=(('mosaic',), None) ] cmx cmw ds in (cmr,) } ] ds in (cmr,) } ] ds cnt:f64[2,2] = log cmr cnu:f64[] = reduce_sum[axes=(0, 1)] cnt cnv:f64[4000] = sub cmq cnu cnw:f64[4000] = sub cjl cnv cnx:f64[] = reduce_max[axes=(0,)] cnw cny:f64[4000] = sub cnw cnx cnz:f64[4000] = exp cny coa:f64[] = 
reduce_sum[axes=(0,)] cnz cob:f64[4000] = div cnz coa coc:f64[4000,1,1] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(4000, 1, 1) sharding=None ] cob cod:f64[4000,2,2] = mul coc cbv coe:f64[2,2] = reduce_sum[axes=(0,)] cod cof:f64[4000,1,2] = slice[ limit_indices=(4000, 1, 2) start_indices=(0, 0, 0) strides=None ] cbv cog:f64[4000,1,2] = slice[ limit_indices=(4000, 2, 2) start_indices=(0, 1, 0) strides=None ] cbv coh:f64[4000,1,4] = concatenate[dimension=2] cof cog coi:f64[1,4,4] = pjit[ name=cov jaxpr={ lambda ; coh:f64[4000,1,4] cob:f64[4000]. let coj:f64[4000,1,4] = pjit[ name=atleast_2d jaxpr={ lambda ; coh:f64[4000,1,4]. let in (coh,) } ] coh cok:f64[1,4,4000] = transpose[permutation=(1, 2, 0)] coj col:f64[4000] = abs cob com:f64[1,4000] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 4000) sharding=None ] col con:f64[1] = reduce_sum[axes=(1,)] com coo:f64[1,1,4000] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 1, 4000) sharding=None ] com cop:f64[1,4,4000] = mul cok coo coq:f64[1,4] = reduce_sum[axes=(2,)] cop cor:f64[1,1] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 1) sharding=None ] con cos:f64[1,4] = div coq cor cot:f64[4] = broadcast_in_dim[ broadcast_dimensions=(0,) shape=(4,) sharding=None ] con cou:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] cot cov:f64[] = squeeze[dimensions=(0,)] cou cow:f64[4000] = mul col col cox:f64[] = reduce_sum[axes=(0,)] cow coy:f64[] = mul 1.0:f64[] cox coz:f64[] = div coy cov cpa:f64[] = sub cov coz cpb:f64[1,4,1] = broadcast_in_dim[ broadcast_dimensions=(0, 1) shape=(1, 4, 1) sharding=None ] cos cpc:f64[1,4,4000] = sub cok cpb cpd:f64[1,4000] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 4000) sharding=None ] col cpe:f64[1,1,4000] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 1, 4000) sharding=None ] cpd cpf:f64[1,4,4000] = mul cpc cpe cpg:f64[1,4000,4] = transpose[permutation=(0, 2, 1)] cpf cph:f64[1,4,4] = dot_general[ 
dimension_numbers=(([2], [1]), ([0], [0])) preferred_element_type=float64 ] cpc cpg coi:f64[1,4,4] = div cph cpa in (coi,) } ] coh cob cpi:f64[2,2] = broadcast_in_dim[ broadcast_dimensions=() shape=(2, 2) sharding=None ] 0.0:f64[] cpj:f64[1,4,4] = slice[ limit_indices=(1, 4, 4) start_indices=(0, 0, 0) strides=None ] coi cpk:f64[4,4] = squeeze[dimensions=(0,)] cpj cpl:f64[2,2] = slice[limit_indices=(2, 2) start_indices=(0, 0) strides=None] cpk cpm:f64[4,4] = pjit[ name=block_diag jaxpr={ lambda ; cpi:f64[2,2] cpl:f64[2,2]. let cpn:f64[2,2] = pjit[name=atleast_2d jaxpr=atleast_2d] cpi cpo:f64[2,2] = pjit[name=atleast_2d jaxpr=atleast_2d] cpl cpp:f64[2,4] = pad[padding_config=((0, 0, 0), (2, 0, 0))] cpo 0.0:f64[] cpq:f64[2,4] = pad[padding_config=((0, 0, 0), (0, 2, 0))] cpn 0.0:f64[] cpm:f64[4,4] = concatenate[dimension=0] cpq cpp in (cpm,) } ] cpi cpl cpr:f64[1,4,4] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(1, 4, 4) sharding=None ] cpm cps:f64[2,4,4] = concatenate[dimension=0] cpr coi cpt:f64[2,4] cpu:f64[2,4,4] = pjit[ name=eigh jaxpr={ lambda ; cps:f64[2,4,4]. let cpv:f64[2,4,4] = transpose[permutation=(0, 2, 1)] cps cpw:f64[2,4,4] = add cps cpv cpx:f64[2,4,4] = div cpw 2.0:f64[] cpu:f64[2,4,4] cpt:f64[2,4] = eigh[ lower=True sort_eigenvalues=True subset_by_index=None ] cpx in (cpt, cpu) } ] cps cpy:f64[2,4] = abs cpt cpz:f64[2,4] = sqrt cpy cqa:f64[2,4,4] = dot_general[ dimension_numbers=(([], []), ([0, 2], [0, 1])) preferred_element_type=float64 ] cpu cpz cqb:f64[2,4,4] = pjit[ name=qr jaxpr={ lambda ; cqa:f64[2,4,4]. 
let cqc:f64[2,4,4] cqb:f64[2,4,4] = qr[ full_matrices=True pivoting=False use_magma=None ] cqa in (cqb,) } ] cqa cqd:u32[4,4] = iota[dimension=0 dtype=uint32 shape=(4, 4) sharding=None] cqe:u32[4,4] = iota[dimension=1 dtype=uint32 shape=(4, 4) sharding=None] cqf:bool[4,4] = eq cqd cqe cqg:bool[2,4,4] = broadcast_in_dim[ broadcast_dimensions=(1, 2) shape=(2, 4, 4) sharding=None ] cqf cqh:f64[2,4,4] = broadcast_in_dim[ broadcast_dimensions=() shape=(2, 4, 4) sharding=None ] 0.0:f64[] cqi:f64[2,4,4] = select_n cqg cqh cqb cqj:f64[2,4] = reduce_sum[axes=(1,)] cqi cqk:f64[2,4,1] = broadcast_in_dim[ broadcast_dimensions=(0, 1) shape=(2, 4, 1) sharding=None ] cqj cql:f64[2,4,1] = sign cqk cqm:f64[2,4,4] = mul cqb cql cqn:f64[2,4,4] = transpose[permutation=(0, 2, 1)] cqm cqo:f64[1,2,2] = slice[ limit_indices=(2, 2, 2) start_indices=(1, 0, 0) strides=None ] cqn cqp:f64[1,2,2] = slice[ limit_indices=(2, 4, 2) start_indices=(1, 2, 0) strides=None ] cqn cqq:f64[2,2,2] = slice[ limit_indices=(2, 4, 4) start_indices=(0, 2, 2) strides=None ] cqn in (dw, coe, cqq, cqo, cqp, cnw) }, ()) During handling of the above exception, another exception occurred: JaxStackTraceBeforeTransformation Traceback (most recent call last) File ~/.local/share/uv/python/cpython-3.10.17-macos-aarch64-none/lib/python3.10/runpy.py:196, in _run_module_as_main() 195 sys.argv[0] = mod_spec.origin --> 196 return _run_code(code, main_globals, None, 197 "__main__", mod_spec) File ~/.local/share/uv/python/cpython-3.10.17-macos-aarch64-none/lib/python3.10/runpy.py:86, in _run_code() 79 run_globals.update(__name__ = mod_name, 80 __file__ = fname, 81 __cached__ = cached, (...) 
84 __package__ = pkg_name, 85 __spec__ = mod_spec) ---> 86 exec(code, run_globals) 87 return run_globals File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel_launcher.py:18 16 from ipykernel import kernelapp as app ---> 18 app.launch_new_instance() File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/traitlets/config/application.py:1075, in launch_instance() 1074 app.initialize(argv) -> 1075 app.start() File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/kernelapp.py:739, in start() 738 try: --> 739 self.io_loop.start() 740 except KeyboardInterrupt: File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tornado/platform/asyncio.py:211, in start() 210 def start(self) -> None: --> 211 self.asyncio_loop.run_forever() File ~/.local/share/uv/python/cpython-3.10.17-macos-aarch64-none/lib/python3.10/asyncio/base_events.py:603, in run_forever() 602 while True: --> 603 self._run_once() 604 if self._stopping: File ~/.local/share/uv/python/cpython-3.10.17-macos-aarch64-none/lib/python3.10/asyncio/base_events.py:1909, in _run_once() 1908 else: -> 1909 handle._run() 1910 handle = None File ~/.local/share/uv/python/cpython-3.10.17-macos-aarch64-none/lib/python3.10/asyncio/events.py:80, in _run() 79 try: ---> 80 self._context.run(self._callback, *self._args) 81 except (SystemExit, KeyboardInterrupt): File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:545, in dispatch_queue() 544 try: --> 545 await self.process_one() 546 except Exception: File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:534, in process_one() 533 return --> 534 await dispatch(*args) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:437, in dispatch_shell() 436 if inspect.isawaitable(result): --> 437 await result 438 except Exception: File 
~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/ipkernel.py:362, in execute_request() 361 self._associate_new_top_level_threads_with(parent_header) --> 362 await super().execute_request(stream, ident, parent) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:778, in execute_request() 777 if inspect.isawaitable(reply_content): --> 778 reply_content = await reply_content 780 # Flush output before sending the reply. File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/ipkernel.py:449, in do_execute() 448 if accepts_params["cell_id"]: --> 449 res = shell.run_cell( 450 code, 451 store_history=store_history, 452 silent=silent, 453 cell_id=cell_id, 454 ) 455 else: File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/zmqshell.py:549, in run_cell() 548 self._last_traceback = None --> 549 return super().run_cell(*args, **kwargs) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3077, in run_cell() 3076 try: -> 3077 result = self._run_cell( 3078 raw_cell, store_history, silent, shell_futures, cell_id 3079 ) 3080 finally: File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3132, in _run_cell() 3131 try: -> 3132 result = runner(coro) 3133 except BaseException as e: File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/IPython/core/async_helpers.py:128, in _pseudo_sync_runner() 127 try: --> 128 coro.send(None) 129 except StopIteration as exc: File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3336, in run_cell_async() 3333 interactivity = "none" if silent else self.ast_node_interactivity -> 3336 has_raised = await self.run_ast_nodes(code_ast.body, cell_name, 3337 interactivity=interactivity, compiler=compiler, result=result) 3339 self.last_execution_succeeded = not has_raised File 
~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3519, in run_ast_nodes() 3518 asy = compare(code) -> 3519 if await self.run_code(code, result, async_=asy): 3520 return True File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3579, in run_code() 3578 else: -> 3579 exec(code_obj, self.user_global_ns, self.user_ns) 3580 finally: 3581 # Reset our crash handler in place Cell In[41], line 2 1 results_are = pd.DataFrame( ----> 2 [ 3 asymptotic_variance(n, k) 4 for n, k in tqdm(zip(ns_are, keys_are), total=len(ns_are)) 5 ] 6 ) 8 results_are.to_csv(here("data/figures/are_meis_cem_ssms.csv"), index=False) Cell In[41], line 3, in <listcomp>() 1 results_are = pd.DataFrame( 2 [ ----> 3 asymptotic_variance(n, k) 4 for n, k in tqdm(zip(ns_are, keys_are), total=len(ns_are)) 5 ] 6 ) 8 results_are.to_csv(here("data/figures/are_meis_cem_ssms.csv"), index=False) Cell In[27], line 37, in asymptotic_variance() 35 sks_cem = sks[M:] ---> 37 logdet_cem = asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key, M=len(sks_cem)) 38 logdet_meis = asymptotic_det_meis( 39 Y, pgssm, prop_la, N_iter, N_samples, key, M=len(sks_meis) 40 ) Cell In[27], line 16, in asymptotic_det_cem() 15 key, *subkeys = jrn.split(key, 1 + M) ---> 16 proposals = [CEM(pgssm, Y, N_samples, sk_cem, N_iter)[0] for sk_cem in subkeys] 17 modes = jnp.array([proposal.mean[:, 0] for proposal in proposals]) Cell In[27], line 16, in <listcomp>() 15 key, *subkeys = jrn.split(key, 1 + M) ---> 16 proposals = [CEM(pgssm, Y, N_samples, sk_cem, N_iter)[0] for sk_cem in subkeys] 17 modes = jnp.array([proposal.mean[:, 0] for proposal in proposals]) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/isssm/ce_method.py:222, in cross_entropy_method() 220 return new_proposal, log_w --> 222 final_proposal, log_w = fori_loop( 223 0, n_iter, _iteration, (initial, jnp.empty(4 * N)) 224 ) 226 return final_proposal, log_w File 
~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/isssm/ce_method.py:202, in _iteration() 200 model_log_weights = partial(log_weight_cem, y=y, model=model, proposal=proposal) --> 202 samples = simulate_cem(proposal, N, subkey_crn) 204 _N, np1, m = samples.shape File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/isssm/ce_method.py:103, in simulate_cem() 102 l_samples = location_antithetic(samples, mean) --> 103 s_samples = scale_antithethic(u, samples, mean) 104 ls_samples = scale_antithethic(u, l_samples, mean) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/isssm/util.py:87, in scale_antithethic() 86 c = jnp.linalg.norm(u, axis=1) ** 2 ---> 87 c_prime = chi_dist.quantile(1.0 - chi_dist.cdf(c)) 89 return mean[None] + jnp.sqrt(c_prime / c)[:, None, None] * (samples - mean[None]) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:1573, in quantile() 1556 """Quantile function. Aka 'inverse cdf' or 'percent point function'. 1557 1558 Given random variable `X` and `p in [0, 1]`, the `quantile` is: (...) 1571 values of type `self.dtype`. 1572 """ -> 1573 return self._call_quantile(value, name, **kwargs) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:1553, in _call_quantile() 1547 value = distribution_util.with_dependencies([ 1548 assert_util.assert_less_equal(value, tf.cast(1, value.dtype), 1549 message='`value` must be <= 1'), 1550 assert_util.assert_greater_equal(value, tf.cast(0, value.dtype), 1551 message='`value` must be >= 0') 1552 ], value) -> 1553 return self._quantile(value, **kwargs) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/chi2.py:139, in _quantile() 138 def _quantile(self, p): --> 139 return 2. 
* special.igammainv(0.5 * self.df, p) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/math/special.py:1573, in igammainv() 1572 p = tf.convert_to_tensor(p, dtype=dtype) -> 1573 return _igammainv_custom_gradient(a, p) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/math/special.py:1546, in _igammainv_custom_gradient() 1541 @tfp_custom_gradient.custom_gradient( 1542 vjp_fwd=_igammainv_fwd, 1543 vjp_bwd=_igammainv_bwd, 1544 jvp_fn=_igammainv_jvp) 1545 def _igammainv_custom_gradient(a, p): -> 1546 return _shared_igammainv_computation(a, p, is_igammainv=True) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/math/special.py:1470, in _shared_igammainv_computation() 1466 factorial = tf.math.exp(a * tf.math.log(x) - x - tf.math.lgamma(a)) 1468 f_over_der = tf.where( 1469 ((p <= 0.9) & is_igammainv) | ((q > 0.9) & (not is_igammainv)), -> 1470 (tf.math.igamma(a, x) - p) * x / factorial, 1471 -(tf.math.igammac(a, x) - q) * x / factorial) 1472 second_der_over_der = -1. + (a - 1.) 
/ x File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/numpy_math.py:656, in <lambda>() 650 greater_equal = utils.copy_docstring( 651 'tf.math.greater_equal', 652 lambda x, y, name=None: np.greater_equal(x, y)) 654 igamma = utils.copy_docstring( 655 'tf.math.igamma', --> 656 lambda a, x, name=None: scipy_special.gammainc(a, x)) 658 igammac = utils.copy_docstring( 659 'tf.math.igammac', 660 lambda a, x, name=None: scipy_special.gammaincc(a, x)) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/scipy/special.py:320, in gammainc() 319 a, x = promote_args_inexact("gammainc", a, x) --> 320 return lax.igamma(a, x) JaxStackTraceBeforeTransformation: KeyboardInterrupt The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception. -------------------- The above exception was the direct cause of the following exception: KeyboardInterrupt Traceback (most recent call last) Cell In[41], line 2 1 results_are = pd.DataFrame( ----> 2 [ 3 asymptotic_variance(n, k) 4 for n, k in tqdm(zip(ns_are, keys_are), total=len(ns_are)) 5 ] 6 ) 8 results_are.to_csv(here("data/figures/are_meis_cem_ssms.csv"), index=False) 9 results_are Cell In[41], line 3, in <listcomp>(.0) 1 results_are = pd.DataFrame( 2 [ ----> 3 asymptotic_variance(n, k) 4 for n, k in tqdm(zip(ns_are, keys_are), total=len(ns_are)) 5 ] 6 ) 8 results_are.to_csv(here("data/figures/are_meis_cem_ssms.csv"), index=False) 9 results_are Cell In[27], line 37, in asymptotic_variance(n, key) 34 sks_meis = sks[:M] 35 sks_cem = sks[M:] ---> 37 logdet_cem = asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key, M=len(sks_cem)) 38 logdet_meis = asymptotic_det_meis( 39 Y, pgssm, prop_la, N_iter, N_samples, key, M=len(sks_meis) 40 ) 42 result = pd.Series( 43 { 44 "n": n, (...) 
50 } 51 ) Cell In[27], line 16, in asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key, M) 14 def asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key, M: int): 15 key, *subkeys = jrn.split(key, 1 + M) ---> 16 proposals = [CEM(pgssm, Y, N_samples, sk_cem, N_iter)[0] for sk_cem in subkeys] 17 modes = jnp.array([proposal.mean[:, 0] for proposal in proposals]) 18 cov = jnp.cov(modes, rowvar=False) * N_samples Cell In[27], line 16, in <listcomp>(.0) 14 def asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key, M: int): 15 key, *subkeys = jrn.split(key, 1 + M) ---> 16 proposals = [CEM(pgssm, Y, N_samples, sk_cem, N_iter)[0] for sk_cem in subkeys] 17 modes = jnp.array([proposal.mean[:, 0] for proposal in proposals]) 18 cov = jnp.cov(modes, rowvar=False) * N_samples File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/isssm/ce_method.py:222, in cross_entropy_method(model, y, N, key, n_iter) 218 new_proposal = proposal_from_moments(mean, consecutive_covs) 220 return new_proposal, log_w --> 222 final_proposal, log_w = fori_loop( 223 0, n_iter, _iteration, (initial, jnp.empty(4 * N)) 224 ) 226 return final_proposal, log_w [... skipping hidden 1 frame] File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:2386, in fori_loop(lower, upper, body_fun, init_val, unroll) 2384 scan_body = _fori_scan_body_fun(body_fun) 2385 api_util.save_wrapped_fun_sourceinfo(scan_body, body_fun) -> 2386 (_, result), _ = scan( 2387 scan_body, 2388 (lower_, init_val), 2389 None, 2390 length=length, 2391 unroll=unroll, 2392 ) 2393 return result 2394 if unroll is not None and unroll is not False and unroll != 1: [... 
skipping hidden 1 frame] File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:355, in scan(f, init, xs, length, reverse, unroll, _split_transpose) 352 consts = [*new_consts, *consts] 353 num_carry -= len(new_consts) --> 355 out = scan_p.bind(*consts, *in_flat, 356 reverse=reverse, length=length, jaxpr=jaxpr, 357 num_consts=len(consts), num_carry=num_carry, 358 linear=(False,) * (len(consts) + len(in_flat)), 359 unroll=unroll, _split_transpose=_split_transpose) 361 if any(move_to_const): 362 out = pe.merge_lists(move_to_const + [False] * num_ys, out, new_consts) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:531, in Primitive.bind(self, *args, **params) 529 def bind(self, *args, **params): 530 args = args if self.skip_canonicalization else map(canonicalize_value, args) --> 531 return self._true_bind(*args, **params) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:551, in Primitive._true_bind(self, *args, **params) 549 trace_ctx.set_trace(eval_trace) 550 try: --> 551 return self.bind_with_trace(prev_trace, args, params) 552 finally: 553 trace_ctx.set_trace(prev_trace) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:556, in Primitive.bind_with_trace(self, trace, args, params) 555 def bind_with_trace(self, trace, args, params): --> 556 return trace.process_primitive(self, args, params) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:1060, in EvalTrace.process_primitive(self, primitive, args, params) 1058 args = map(full_lower, args) 1059 check_eval_args(args) -> 1060 return primitive.impl(*args, **params) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/dispatch.py:88, in apply_primitive(prim, *args, **params) 86 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False) 87 try: ---> 88 outs = fun(*args) 89 finally: 90 
lib.jax_jit.swap_thread_local_state_disable_jit(prev) [... skipping hidden 1 frame] File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:334, in _cpp_pjit.<locals>.cache_miss(*args, **kwargs) 329 if config.no_tracing.value: 330 raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for " 331 "`jit`, but 'no_tracing' is set") 333 (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, box_data, --> 334 executable, pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs) 336 maybe_fastpath_data = _get_fastpath_data( 337 executable, out_tree, args_flat, out_flat, attrs_tracked, box_data, 338 jaxpr.effects, jaxpr.consts, jit_info.abstracted_axes, pgle_profiler) 340 return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:195, in _python_pjit_helper(fun, jit_info, *args, **kwargs) 193 args_flat = map(core.full_lower, args_flat) 194 core.check_eval_args(args_flat) --> 195 out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params) 196 else: 197 out_flat = pjit_p.bind(*args_flat, **p.params) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:1853, in _pjit_call_impl_python(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs, *args) 1850 compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items()) 1851 # Passing mutable PGLE profile here since it should be extracted by JAXPR to 1852 # initialize the fdo_profile compile option. 
-> 1853 compiled = _resolve_and_lower( 1854 args, jaxpr=jaxpr, in_shardings=in_shardings, 1855 out_shardings=out_shardings, in_layouts=in_layouts, 1856 out_layouts=out_layouts, donated_invars=donated_invars, 1857 ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, 1858 inline=inline, lowering_platforms=None, 1859 lowering_parameters=mlir.LoweringParameters(), 1860 pgle_profiler=pgle_profiler, 1861 compiler_options_kvs=compiler_options_kvs, 1862 ).compile() 1864 # This check is expensive so only do it if enable_checks is on. 1865 if compiled._auto_spmd_lowering and config.enable_checks.value: File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:1820, in _resolve_and_lower(args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, lowering_platforms, lowering_parameters, pgle_profiler, compiler_options_kvs) 1817 in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings, 1818 jaxpr.in_avals) 1819 out_layouts = _resolve_out_layouts(out_layouts, out_shardings, jaxpr.out_avals) -> 1820 return _pjit_lower( 1821 jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, 1822 donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs, 1823 lowering_platforms=lowering_platforms, 1824 lowering_parameters=lowering_parameters, 1825 pgle_profiler=pgle_profiler) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:1953, in _pjit_lower(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs, lowering_platforms, lowering_parameters, pgle_profiler) 1936 def _pjit_lower( 1937 jaxpr: core.ClosedJaxpr, 1938 in_shardings, (...) 
1950 lowering_parameters: mlir.LoweringParameters, 1951 pgle_profiler: profiler.PGLEProfiler | None): 1952 util.test_event("pjit_lower") -> 1953 return pxla.lower_sharding_computation( 1954 jaxpr, 'jit', name, in_shardings, out_shardings, 1955 in_layouts, out_layouts, tuple(donated_invars), 1956 keep_unused=keep_unused, context_mesh=ctx_mesh, 1957 compiler_options_kvs=compiler_options_kvs, 1958 lowering_platforms=lowering_platforms, 1959 lowering_parameters=lowering_parameters, 1960 pgle_profiler=pgle_profiler) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/profiler.py:354, in annotate_function.<locals>.wrapper(*args, **kwargs) 351 @wraps(func) 352 def wrapper(*args, **kwargs): 353 with TraceAnnotation(name, **decorator_kwargs): --> 354 return func(*args, **kwargs) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:2378, in lower_sharding_computation(closed_jaxpr, api_name, fun_name, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, keep_unused, context_mesh, compiler_options_kvs, lowering_platforms, lowering_parameters, pgle_profiler) 2372 semantic_in_shardings = SemanticallyEqualShardings( 2373 in_shardings, global_in_avals) 2374 semantic_out_shardings = SemanticallyEqualShardings( 2375 out_shardings, global_out_avals) 2377 (module, keepalive, host_callbacks, unordered_effects, ordered_effects, -> 2378 nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo( 2379 closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, 2380 semantic_out_shardings, in_layouts, out_layouts, num_devices, 2381 tuple(da_object) if prim_requires_devices else None, # type: ignore[arg-type] 2382 donated_invars, name_stack, all_default_mem_kind, inout_aliases, 2383 propagated_out_mem_kinds, platforms, 2384 lowering_parameters=lowering_parameters, 2385 abstract_mesh=abstract_mesh) 2387 # backend and device_assignment is passed through to MeshExecutable because 2388 # if 
keep_unused=False and all in_shardings are pruned, then there is no way 2389 # to get the device_assignment and backend. So pass it to MeshExecutable 2390 # because we calculate the device_assignment and backend before in_shardings, 2391 # etc are pruned. 2392 return MeshComputation( 2393 str(name_stack), 2394 module, (...) 2421 intermediate_shardings=unique_intermediate_shardings, 2422 context_mesh=context_mesh) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1968, in _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, semantic_out_shardings, in_layouts, out_layouts, num_devices, device_assignment, donated_invars, name_stack, all_default_mem_kind, inout_aliases, propagated_out_mem_kinds, platforms, lowering_parameters, abstract_mesh) 1964 ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects)) 1965 with dispatch.log_elapsed_time( 1966 "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time:.9f} sec", 1967 fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): -> 1968 lowering_result = mlir.lower_jaxpr_to_module( 1969 module_name, 1970 closed_jaxpr, 1971 ordered_effects=ordered_effects, 1972 backend=backend, 1973 platforms=platforms, 1974 axis_context=axis_ctx, 1975 name_stack=name_stack, 1976 donated_args=donated_invars, 1977 replicated_args=replicated_args, 1978 arg_shardings=in_mlir_shardings, 1979 result_shardings=out_mlir_shardings, 1980 in_layouts=in_layouts, 1981 out_layouts=out_layouts, 1982 arg_names=jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)), 1983 result_names=jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)), 1984 num_replicas=nreps, 1985 num_partitions=num_partitions, 1986 all_default_mem_kind=all_default_mem_kind, 1987 input_output_aliases=inout_aliases, 1988 propagated_out_mem_kinds=propagated_out_mem_kinds, 1989 lowering_parameters=lowering_parameters) 1990 tuple_args = 
dispatch.should_tuple_args(len(global_in_avals), backend.platform) 1991 unordered_effects = list( 1992 effects.ordered_effects.filter_not_in(closed_jaxpr.effects)) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1271, in lower_jaxpr_to_module(***failed resolving arguments***) 1269 attrs["mhlo.num_replicas"] = i32_attr(num_replicas) 1270 attrs["mhlo.num_partitions"] = i32_attr(num_partitions) -> 1271 lower_jaxpr_to_fun( 1272 ctx, "main", jaxpr, ordered_effects, 1273 name_stack=name_stack, 1274 public=True, 1275 replicated_args=replicated_args, 1276 arg_shardings=arg_shardings, 1277 result_shardings=result_shardings, 1278 input_output_aliases=input_output_aliases, 1279 xla_donated_args=xla_donated_args, 1280 arg_names=arg_names, 1281 result_names=result_names, 1282 arg_memory_kinds=arg_memory_kinds, 1283 result_memory_kinds=result_memory_kinds, 1284 arg_layouts=in_layouts, 1285 result_layouts=out_layouts, 1286 propagated_out_mem_kinds=propagated_out_mem_kinds) 1288 try: 1289 if not ctx.module.operation.verify(): File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1763, in lower_jaxpr_to_fun(ctx, name, jaxpr, effects, name_stack, public, replicated_args, arg_shardings, result_shardings, use_sharding_annotations, input_output_aliases, xla_donated_args, api_name, arg_names, result_names, arg_memory_kinds, result_memory_kinds, arg_layouts, result_layouts, propagated_out_mem_kinds) 1761 callee_name_stack = name_stack 1762 consts = [ir_constant(xla.canonicalize_dtype(x)) for x in jaxpr.consts] -> 1763 out_vals, tokens_out = jaxpr_subcomp( 1764 ctx, jaxpr.jaxpr, callee_name_stack, tokens_in, 1765 consts, *args, dim_var_values=dim_var_values) 1766 outs: list[IrValues] = [] 1767 for eff in effects: File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2040, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args) 
2037 rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env) 2039 assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes) -> 2040 ans = lower_per_platform(rule_ctx, str(eqn.primitive), 2041 platform_rules, default_rule, 2042 eqn.effects, 2043 *in_nodes, **eqn.params) 2045 if effects: 2046 # If there were ordered effects in the primitive, there should be output 2047 # tokens we need for subsequent ordered effects. 2048 tokens_out = rule_ctx.tokens_out File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2162, in lower_per_platform(ctx, description, platform_rules, default_rule, effects, *rule_args, **rule_kwargs) 2160 # If there is a single rule left just apply the rule, without conditionals. 2161 if len(kept_rules) == 1: -> 2162 output = kept_rules[0](ctx, *rule_args, **rule_kwargs) 2163 foreach( 2164 lambda o: wrap_compute_type_in_place(ctx, o.owner), 2165 filter(_is_not_block_argument, flatten_ir_values(output)), 2166 ) 2167 foreach( 2168 lambda o: wrap_xla_metadata_in_place(ctx, o.owner), 2169 flatten_ir_values(output), 2170 ) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2278, in lower_fun.<locals>.f_lowered(ctx, *args, **params) 2276 else: 2277 sub_context = ctx.module_context -> 2278 out, tokens = jaxpr_subcomp( 2279 sub_context, jaxpr, ctx.name_stack, ctx.tokens_in, 2280 _ir_consts(consts), *args, 2281 dim_var_values=ctx.dim_var_values) 2282 ctx.set_tokens_out(tokens) 2283 return out File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2040, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args) 2037 rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env) 2039 assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes) -> 2040 ans = lower_per_platform(rule_ctx, str(eqn.primitive), 2041 platform_rules, default_rule, 2042 eqn.effects, 2043 *in_nodes, **eqn.params) 2045 if effects: 
2046 # If there were ordered effects in the primitive, there should be output 2047 # tokens we need for subsequent ordered effects. 2048 tokens_out = rule_ctx.tokens_out File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2162, in lower_per_platform(ctx, description, platform_rules, default_rule, effects, *rule_args, **rule_kwargs) 2160 # If there is a single rule left just apply the rule, without conditionals. 2161 if len(kept_rules) == 1: -> 2162 output = kept_rules[0](ctx, *rule_args, **rule_kwargs) 2163 foreach( 2164 lambda o: wrap_compute_type_in_place(ctx, o.owner), 2165 filter(_is_not_block_argument, flatten_ir_values(output)), 2166 ) 2167 foreach( 2168 lambda o: wrap_xla_metadata_in_place(ctx, o.owner), 2169 flatten_ir_values(output), 2170 ) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:2101, in _while_lowering(ctx, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts, *args) 2098 body_name_stack = name_stack.extend('body') 2099 body_consts = [mlir.ir_constant(xla.canonicalize_dtype(x)) 2100 for x in body_jaxpr.consts] -> 2101 new_z, tokens_out = mlir.jaxpr_subcomp( 2102 ctx.module_context, body_jaxpr.jaxpr, body_name_stack, 2103 tokens_in, body_consts, *(y + z), dim_var_values=ctx.dim_var_values) 2104 out_tokens = [tokens_out.get(eff) for eff in body_effects] 2105 if batched: File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2040, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args) 2037 rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env) 2039 assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes) -> 2040 ans = lower_per_platform(rule_ctx, str(eqn.primitive), 2041 platform_rules, default_rule, 2042 eqn.effects, 2043 *in_nodes, **eqn.params) 2045 if effects: 2046 # If there were ordered effects in the primitive, there should be output 2047 # tokens we need for subsequent 
ordered effects. 2048 tokens_out = rule_ctx.tokens_out File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2162, in lower_per_platform(ctx, description, platform_rules, default_rule, effects, *rule_args, **rule_kwargs) 2160 # If there is a single rule left just apply the rule, without conditionals. 2161 if len(kept_rules) == 1: -> 2162 output = kept_rules[0](ctx, *rule_args, **rule_kwargs) 2163 foreach( 2164 lambda o: wrap_compute_type_in_place(ctx, o.owner), 2165 filter(_is_not_block_argument, flatten_ir_values(output)), 2166 ) 2167 foreach( 2168 lambda o: wrap_xla_metadata_in_place(ctx, o.owner), 2169 flatten_ir_values(output), 2170 ) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2383, in core_call_lowering(ctx, name, backend, call_jaxpr, *args) 2381 def core_call_lowering(ctx: LoweringRuleContext, 2382 *args, name, backend=None, call_jaxpr): -> 2383 out_nodes, tokens = call_lowering( 2384 name, ctx.name_stack, call_jaxpr, backend, ctx.module_context, 2385 ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args, 2386 dim_var_values=ctx.dim_var_values) 2387 ctx.set_tokens_out(tokens) 2388 return out_nodes File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2366, in call_lowering(***failed resolving arguments***) 2360 def call_lowering(fn_name, name_stack, call_jaxpr, backend, 2361 ctx: ModuleContext, avals_in, 2362 avals_out, tokens_in, *args, 2363 dim_var_values: Sequence[ir.Value], 2364 arg_names=None, result_names=None): 2365 del avals_in -> 2366 func_op, output_types, effects = lower_called_computation( 2367 fn_name, name_stack, call_jaxpr, ctx, avals_out, tokens_in, 2368 backend=backend, arg_names=arg_names, result_names=result_names) 2369 symbol_name = func_op.name.value 2370 flat_output_types = flatten_ir_types(output_types) File 
~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2348, in lower_called_computation(fn_name, name_stack, call_jaxpr, ctx, avals_out, tokens_in, backend, arg_names, result_names) 2346 output_types = map(aval_to_ir_type, avals_out) 2347 output_types = [token_type()] * len(effects) + output_types -> 2348 func_op = _lower_jaxpr_to_fun_cached( 2349 ctx, 2350 fn_name, 2351 call_jaxpr, 2352 effects, 2353 name_stack, 2354 arg_names=arg_names, 2355 result_names=result_names, 2356 ) 2357 return func_op, output_types, effects File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2297, in _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, name_stack, arg_names, result_names) 2295 except KeyError: 2296 num_callbacks = len(ctx.host_callbacks) -> 2297 func_op = lower_jaxpr_to_fun( 2298 ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names, 2299 result_names=result_names) 2301 # If this Jaxpr includes callbacks, we can't cache the lowering because 2302 # on TPU every callback must have a globally unique channel, but the 2303 # channel gets assigned during lowering. 
2304 has_callbacks = len(ctx.host_callbacks) > num_callbacks File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1763, in lower_jaxpr_to_fun(ctx, name, jaxpr, effects, name_stack, public, replicated_args, arg_shardings, result_shardings, use_sharding_annotations, input_output_aliases, xla_donated_args, api_name, arg_names, result_names, arg_memory_kinds, result_memory_kinds, arg_layouts, result_layouts, propagated_out_mem_kinds) 1761 callee_name_stack = name_stack 1762 consts = [ir_constant(xla.canonicalize_dtype(x)) for x in jaxpr.consts] -> 1763 out_vals, tokens_out = jaxpr_subcomp( 1764 ctx, jaxpr.jaxpr, callee_name_stack, tokens_in, 1765 consts, *args, dim_var_values=dim_var_values) 1766 outs: list[IrValues] = [] 1767 for eff in effects: File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2040, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args) 2037 rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env) 2039 assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes) -> 2040 ans = lower_per_platform(rule_ctx, str(eqn.primitive), 2041 platform_rules, default_rule, 2042 eqn.effects, 2043 *in_nodes, **eqn.params) 2045 if effects: 2046 # If there were ordered effects in the primitive, there should be output 2047 # tokens we need for subsequent ordered effects. 2048 tokens_out = rule_ctx.tokens_out File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2162, in lower_per_platform(ctx, description, platform_rules, default_rule, effects, *rule_args, **rule_kwargs) 2160 # If there is a single rule left just apply the rule, without conditionals. 
2161 if len(kept_rules) == 1: -> 2162 output = kept_rules[0](ctx, *rule_args, **rule_kwargs) 2163 foreach( 2164 lambda o: wrap_compute_type_in_place(ctx, o.owner), 2165 filter(_is_not_block_argument, flatten_ir_values(output)), 2166 ) 2167 foreach( 2168 lambda o: wrap_xla_metadata_in_place(ctx, o.owner), 2169 flatten_ir_values(output), 2170 ) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/custom_derivatives.py:430, in _custom_jvp_vjp_call_lowering(ctx, call_jaxpr, *args, **_) 428 def _custom_jvp_vjp_call_lowering(ctx, *args, call_jaxpr, **_): 429 consts = mlir._ir_consts(call_jaxpr.consts) --> 430 out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr, 431 ctx.name_stack, ctx.tokens_in, consts, 432 *args, dim_var_values=ctx.dim_var_values) 433 ctx.set_tokens_out(tokens) 434 return out File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2040, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args) 2037 rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env) 2039 assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes) -> 2040 ans = lower_per_platform(rule_ctx, str(eqn.primitive), 2041 platform_rules, default_rule, 2042 eqn.effects, 2043 *in_nodes, **eqn.params) 2045 if effects: 2046 # If there were ordered effects in the primitive, there should be output 2047 # tokens we need for subsequent ordered effects. 2048 tokens_out = rule_ctx.tokens_out File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2162, in lower_per_platform(ctx, description, platform_rules, default_rule, effects, *rule_args, **rule_kwargs) 2160 # If there is a single rule left just apply the rule, without conditionals. 
2161 if len(kept_rules) == 1: -> 2162 output = kept_rules[0](ctx, *rule_args, **rule_kwargs) 2163 foreach( 2164 lambda o: wrap_compute_type_in_place(ctx, o.owner), 2165 filter(_is_not_block_argument, flatten_ir_values(output)), 2166 ) 2167 foreach( 2168 lambda o: wrap_xla_metadata_in_place(ctx, o.owner), 2169 flatten_ir_values(output), 2170 ) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2278, in lower_fun.<locals>.f_lowered(ctx, *args, **params) 2276 else: 2277 sub_context = ctx.module_context -> 2278 out, tokens = jaxpr_subcomp( 2279 sub_context, jaxpr, ctx.name_stack, ctx.tokens_in, 2280 _ir_consts(consts), *args, 2281 dim_var_values=ctx.dim_var_values) 2282 ctx.set_tokens_out(tokens) 2283 return out File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2006, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args) 2003 in_nodes = map(read, eqn.invars) 2004 source_info = eqn.source_info.replace( 2005 name_stack=name_stack + eqn.source_info.name_stack) -> 2006 loc = _source_info_to_location(ctx, eqn.primitive, source_info) 2007 with (source_info_util.user_context(eqn.source_info.traceback), loc, 2008 eqn.ctx.manager): 2009 override_rule = get_override_lowering_rule(eqn.primitive) File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:516, in _source_info_to_location(ctx, primitive, source_info) 512 else: 513 loc = ir.Location.file(get_canonical_source_file(frame.file_name, 514 ctx.traceback_caches), 515 frame.start_line, frame.start_column) --> 516 loc = ir.Location.name(eqn_str, childLoc=loc) 517 # TODO(phawkins): also include primitive.name as the operator type. 518 return loc KeyboardInterrupt: